From 3d735dca8700bb7faea3642e2285fc02ca675882 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Mon, 9 Dec 2024 12:38:04 +0800 Subject: [PATCH] Add support to iframe chatbot (#3929) ### What problem does this PR solve? #3909 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- agent/canvas.py | 11 + agent/component/base.py | 7 +- agent/component/begin.py | 1 + agent/templates/customer_service.json | 6 +- api/apps/conversation_app.py | 4 +- api/apps/sdk/session.py | 316 +++++------------------- api/apps/system_app.py | 1 + api/db/db_models.py | 7 + api/db/services/canvas_service.py | 118 +++++++++ api/db/services/conversation_service.py | 221 +++++++++++++++++ api/db/services/dialog_service.py | 23 +- api/db/services/document_service.py | 3 +- 12 files changed, 438 insertions(+), 280 deletions(-) create mode 100644 api/db/services/conversation_service.py diff --git a/agent/canvas.py b/agent/canvas.py index bfffabd59..d8b04a983 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -21,6 +21,7 @@ from functools import partial from agent.component import component_class from agent.component.base import ComponentBase + class Canvas(ABC): """ dsl = { @@ -320,3 +321,13 @@ class Canvas(ABC): def get_prologue(self): return self.components["begin"]["obj"]._param.prologue + + def set_global_param(self, **kwargs): + for k, v in kwargs.items(): + for q in self.components["begin"]["obj"]._param.query: + if k != q["key"]: + continue + q["value"] = v + + def get_preset_param(self): + return self.components["begin"]["obj"]._param.query \ No newline at end of file diff --git a/agent/component/base.py b/agent/component/base.py index 2660be7d3..2dc0cd49b 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -383,9 +383,6 @@ class ComponentBase(ABC): "params": {} } """ - out = json.loads(str(self._param)).get("output", {}) - if isinstance(out, dict) and "vector" in out: - del out["vector"] return """{{ "component_name": "{}", "params": {}, 
@@ -393,7 +390,7 @@ class ComponentBase(ABC): "inputs": {} }}""".format(self.component_name, self._param, - json.dumps(out, ensure_ascii=False), + json.dumps(json.loads(str(self._param)).get("output", {}), ensure_ascii=False), json.dumps(json.loads(str(self._param)).get("inputs", []), ensure_ascii=False) ) @@ -462,7 +459,7 @@ class ComponentBase(ABC): self._param.inputs = [] outs = [] for q in self._param.query: - if q["component_id"]: + if q.get("component_id"): if q["component_id"].split("@")[0].lower().find("begin") >= 0: cpn_id, key = q["component_id"].split("@") for p in self._canvas.get_component(cpn_id)["obj"]._param.query: diff --git a/agent/component/begin.py b/agent/component/begin.py index 037a8a057..766c0c667 100644 --- a/agent/component/begin.py +++ b/agent/component/begin.py @@ -26,6 +26,7 @@ class BeginParam(ComponentParamBase): def __init__(self): super().__init__() self.prologue = "Hi! I'm your smart assistant. What can I do for you?" + self.query = [] def check(self): return True diff --git a/agent/templates/customer_service.json b/agent/templates/customer_service.json index edc9931c1..e8aa89b63 100644 --- a/agent/templates/customer_service.json +++ b/agent/templates/customer_service.json @@ -336,7 +336,7 @@ "parameters": [], "presencePenaltyEnabled": true, "presence_penalty": 0.4, - "prompt": "Role: You are a customer support. \n\nTask: Please answer the question based on content of knowledge base. \n\nRequirements & restrictions:\n - DO NOT make things up when all knowledge base content is irrelevant to the question. \n - Answers need to consider chat history.\n - Request about customer's contact information like, Wechat number, LINE number, twitter, discord, etc,. , when knowledge base content can't answer his question. So, product expert could contact him soon to solve his problem.\n\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.", + "prompt": "Role: You are a customer support. 
\n\nTask: Please answer the question based on content of knowledge base. \n\nRequirements & restrictions:\n - DO NOT make things up when all knowledge base content is irrelevant to the question. \n - Answers need to consider chat history.\n - Request about customer's contact information like, Wechat number, LINE number, twitter, discord, etc,. , when knowledge base content can't answer his question. So, product expert could contact him soon to solve his problem.\n\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.", "temperature": 0.1, "temperatureEnabled": true, "topPEnabled": true, @@ -603,7 +603,7 @@ { "data": { "form": { - "text": "Static messages.\nDefine response after receive user's contact information." + "text": "Static messages.\nDefine replies after receiving user's contact information." }, "label": "Note", "name": "N: What else?" @@ -691,7 +691,7 @@ { "data": { "form": { - "text": "Complete questions by conversation history.\nUser: What's RAGFlow?\nAssistant: RAGFlow is xxx.\nUser: How to deploy it?\n\nRefine it: How to deploy RAGFlow?" + "text": "Complete questions by conversation history.\nUser: What's RAGFlow?\nAssistant: RAGFlow is xxx.\nUser: How to deploy it?\n\nRefine it: How to deploy RAGFlow?"
}, "label": "Note", "name": "N: Refine Question" diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index cce87c337..d787e5f1f 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -17,12 +17,14 @@ import json import re import traceback from copy import deepcopy + +from api.db.services.conversation_service import ConversationService from api.db.services.user_service import UserTenantService from flask import request, Response from flask_login import login_required, current_user from api.db import LLMType -from api.db.services.dialog_service import DialogService, ConversationService, chat, ask +from api.db.services.dialog_service import DialogService, chat, ask from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService from api import settings diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index b78106130..14b840eca 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -15,17 +15,19 @@ # import re import json -from copy import deepcopy -from uuid import uuid4 from api.db import LLMType from flask import request, Response + +from api.db.services.conversation_service import ConversationService, iframe_completion +from api.db.services.conversation_service import completion as rag_completion +from api.db.services.canvas_service import completion as agent_completion from api.db.services.dialog_service import ask from agent.canvas import Canvas from api.db import StatusEnum -from api.db.db_models import API4Conversation +from api.db.db_models import APIToken from api.db.services.api_service import API4ConversationService from api.db.services.canvas_service import UserCanvasService -from api.db.services.dialog_service import DialogService, ConversationService, chat +from api.db.services.dialog_service import DialogService from api.db.services.knowledgebase_service import KnowledgebaseService from 
api.utils import get_uuid from api.utils.api_utils import get_error_data_result @@ -66,8 +68,6 @@ def create_agent_session(tenant_id, agent_id): e, cvs = UserCanvasService.get_by_id(agent_id) if not e: return get_error_data_result("Agent not found.") - if cvs.user_id != tenant_id: - return get_error_data_result(message="You do not own the agent.") if not isinstance(cvs.dsl, str): cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) @@ -110,98 +110,10 @@ def update(tenant_id, chat_id, session_id): @manager.route('/chats//completions', methods=['POST']) # noqa: F821 @token_required -def completion(tenant_id, chat_id): - dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value) - if not dia: - return get_error_data_result(message="You do not own the chat") +def chat_completion(tenant_id, chat_id): req = request.json - if not req.get("session_id"): - conv = { - "id": get_uuid(), - "dialog_id": chat_id, - "name": req.get("name", "New session"), - "message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}] - } - if not conv.get("name"): - return get_error_data_result(message="`name` can not be empty.") - ConversationService.save(**conv) - e, conv = ConversationService.get_by_id(conv["id"]) - session_id = conv.id - else: - session_id = req.get("session_id") - if not req.get("question"): - return get_error_data_result(message="Please input your question.") - conv = ConversationService.query(id=session_id, dialog_id=chat_id) - if not conv: - return get_error_data_result(message="Session does not exist") - conv = conv[0] - msg = [] - question = { - "content": req.get("question"), - "role": "user", - "id": str(uuid4()) - } - conv.message.append(question) - for m in conv.message: - if m["role"] == "system": - continue - if m["role"] == "assistant" and not msg: - continue - msg.append(m) - message_id = msg[-1].get("id") - e, dia = DialogService.get_by_id(conv.dialog_id) - - if not conv.reference: - conv.reference = [] - 
conv.message.append({"role": "assistant", "content": "", "id": message_id}) - conv.reference.append({"chunks": [], "doc_aggs": []}) - - def fillin_conv(ans): - reference = ans["reference"] - temp_reference = deepcopy(ans["reference"]) - nonlocal conv, message_id - if not conv.reference: - conv.reference.append(temp_reference) - else: - conv.reference[-1] = temp_reference - conv.message[-1] = {"role": "assistant", "content": ans["answer"], - "id": message_id, "prompt": ans.get("prompt", "")} - if "chunks" in reference: - chunks = reference.get("chunks") - chunk_list = [] - for chunk in chunks: - new_chunk = { - "id": chunk["chunk_id"], - "content": chunk["content_with_weight"], - "document_id": chunk["doc_id"], - "document_name": chunk["docnm_kwd"], - "dataset_id": chunk["kb_id"], - "image_id": chunk.get("image_id", ""), - "similarity": chunk["similarity"], - "vector_similarity": chunk["vector_similarity"], - "term_similarity": chunk["term_similarity"], - "positions": chunk.get("positions", []), - } - chunk_list.append(new_chunk) - reference["chunks"] = chunk_list - ans["id"] = message_id - ans["session_id"] = session_id - - def stream(): - nonlocal dia, msg, req, conv - try: - for ans in chat(dia, msg, **req): - fillin_conv(ans) - yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n" - ConversationService.update_by_id(conv.id, conv.to_dict()) - except Exception as e: - yield "data:" + json.dumps({"code": 500, "message": str(e), - "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, - ensure_ascii=False) + "\n\n" - yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n" - if req.get("stream", True): - resp = Response(stream(), mimetype="text/event-stream") + resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream") resp.headers.add_header("Cache-control", "no-cache") resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("X-Accel-Buffering", "no") @@ 
-211,172 +123,26 @@ def completion(tenant_id, chat_id): else: answer = None - for ans in chat(dia, msg, **req): + for ans in rag_completion(tenant_id, chat_id, **req): answer = ans - fillin_conv(ans) - ConversationService.update_by_id(conv.id, conv.to_dict()) break return get_result(data=answer) @manager.route('/agents//completions', methods=['POST']) # noqa: F821 @token_required -def agent_completion(tenant_id, agent_id): +def agent_completions(tenant_id, agent_id): req = request.json - - e, cvs = UserCanvasService.get_by_id(agent_id) - if not e: - return get_error_data_result("Agent not found.") - if cvs.user_id != tenant_id: - return get_error_data_result(message="You do not own the agent.") - if not isinstance(cvs.dsl, str): - cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) - canvas = Canvas(cvs.dsl, tenant_id) - - if not req.get("session_id"): - session_id = get_uuid() - conv = { - "id": session_id, - "dialog_id": cvs.id, - "user_id": req.get("user_id", ""), - "message": [{"role": "assistant", "content": canvas.get_prologue()}], - "source": "agent", - "dsl": json.loads(cvs.dsl) - } - API4ConversationService.save(**conv) - conv = API4Conversation(**conv) - else: - session_id = req.get("session_id") - e, conv = API4ConversationService.get_by_id(req["session_id"]) - if not e: - return get_error_data_result(message="Session not found!") - canvas = Canvas(json.dumps(conv.dsl), tenant_id) - - messages = conv.message - question = req.get("question") - if not question: - return get_error_data_result("`question` is required.") - question = { - "role": "user", - "content": question, - "id": str(uuid4()) - } - messages.append(question) - msg = [] - for m in messages: - if m["role"] == "system": - continue - if m["role"] == "assistant" and not msg: - continue - msg.append(m) - if not msg[-1].get("id"): - msg[-1]["id"] = get_uuid() - message_id = msg[-1]["id"] - - stream = req.get("stream", True) - - def fillin_conv(ans): - reference = ans["reference"] - temp_reference = 
deepcopy(ans["reference"]) - nonlocal conv, message_id - if not conv.reference: - conv.reference.append(temp_reference) - else: - conv.reference[-1] = temp_reference - conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id} - if "chunks" in reference: - chunks = reference.get("chunks") - chunk_list = [] - for chunk in chunks: - new_chunk = { - "id": chunk["chunk_id"], - "content": chunk["content"], - "document_id": chunk["doc_id"], - "document_name": chunk["docnm_kwd"], - "dataset_id": chunk["kb_id"], - "image_id": chunk["image_id"], - "similarity": chunk["similarity"], - "vector_similarity": chunk["vector_similarity"], - "term_similarity": chunk["term_similarity"], - "positions": chunk["positions"], - } - chunk_list.append(new_chunk) - reference["chunks"] = chunk_list - ans["id"] = message_id - ans["session_id"] = session_id - - def rename_field(ans): - reference = ans['reference'] - if not isinstance(reference, dict): - return - for chunk_i in reference.get('chunks', []): - if 'docnm_kwd' in chunk_i: - chunk_i['doc_name'] = chunk_i['docnm_kwd'] - chunk_i.pop('docnm_kwd') - - if not conv.reference: - conv.reference = [] - conv.message.append({"role": "assistant", "content": "", "id": message_id}) - conv.reference.append({"chunks": [], "doc_aggs": []}) - - final_ans = {"reference": [], "content": ""} - - canvas.add_user_input(msg[-1]["content"]) - - if stream: - def sse(): - nonlocal answer, cvs - try: - for ans in canvas.run(stream=stream): - if ans.get("running_status"): - yield "data:" + json.dumps({"code": 0, "message": "", - "data": {"answer": ans["content"], - "running_status": True}}, - ensure_ascii=False) + "\n\n" - continue - for k in ans.keys(): - final_ans[k] = ans[k] - ans = {"answer": ans["content"], "reference": ans.get("reference", [])} - fillin_conv(ans) - rename_field(ans) - yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, - ensure_ascii=False) + "\n\n" - - canvas.messages.append({"role": "assistant", 
"content": final_ans["content"], "id": message_id}) - canvas.history.append(("assistant", final_ans["content"])) - if final_ans.get("reference"): - canvas.reference.append(final_ans["reference"]) - conv.dsl = json.loads(str(canvas)) - API4ConversationService.append_message(conv.id, conv.to_dict()) - except Exception as e: - conv.dsl = json.loads(str(canvas)) - API4ConversationService.append_message(conv.id, conv.to_dict()) - yield "data:" + json.dumps({"code": 500, "message": str(e), - "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, - ensure_ascii=False) + "\n\n" - yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" - - resp = Response(sse(), mimetype="text/event-stream") + if req.get("stream", True): + resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream") resp.headers.add_header("Cache-control", "no-cache") resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - for answer in canvas.run(stream=False): - if answer.get("running_status"): - continue - final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else "" - canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) - if final_ans.get("reference"): - canvas.reference.append(final_ans["reference"]) - conv.dsl = json.loads(str(canvas)) - - result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])} - fillin_conv(result) - API4ConversationService.append_message(conv.id, conv.to_dict()) - rename_field(result) - return get_result(data=result) + for answer in agent_completion(tenant_id, agent_id, **req): + return get_result(data=answer) @manager.route('/chats//sessions', methods=['GET']) # noqa: F821 @@ -590,3 +356,57 @@ Keywords: {question} Related search terms: """}], {"temperature": 0.9}) return 
get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]) + + +@manager.route('/chatbots//completions', methods=['POST']) # noqa: F821 +def chatbot_completions(dialog_id): + req = request.json + + token = request.headers.get('Authorization').split() + if len(token) != 2: + return get_error_data_result(message='Authorization is not valid!"') + token = token[1] + objs = APIToken.query(beta=token) + if not objs: + return get_error_data_result(message='Token is not valid!"') + + if "quote" not in req: + req["quote"] = False + + if req.get("stream", True): + resp = Response(iframe_completion(objs[0].tenant_id, dialog_id, **req), mimetype="text/event-stream") + resp.headers.add_header("Cache-control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + return resp + + for answer in agent_completion(objs[0].tenant_id, dialog_id, **req): + return get_result(data=answer) + + +@manager.route('/agentbots//completions', methods=['POST']) # noqa: F821 +def agent_bot_completions(agent_id): + req = request.json + + token = request.headers.get('Authorization').split() + if len(token) != 2: + return get_error_data_result(message='Authorization is not valid!"') + token = token[1] + objs = APIToken.query(beta=token) + if not objs: + return get_error_data_result(message='Token is not valid!"') + + if "quote" not in req: + req["quote"] = False + + if req.get("stream", True): + resp = Response(agent_completion(objs[0].tenant_id, agent_id, **req), mimetype="text/event-stream") + resp.headers.add_header("Cache-control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + return resp + + for answer in agent_completion(objs[0].tenant_id, agent_id, **req): + 
return get_result(data=answer) diff --git a/api/apps/system_app.py b/api/apps/system_app.py index 7b715f14b..28b74d658 100644 --- a/api/apps/system_app.py +++ b/api/apps/system_app.py @@ -205,6 +205,7 @@ def new_token(): obj = { "tenant_id": tenant_id, "token": generate_confirmation_token(tenant_id), + "beta": generate_confirmation_token(generate_confirmation_token(tenant_id)).replace("ragflow-", "")[:32], "create_time": current_timestamp(), "create_date": datetime_format(datetime.now()), "update_time": None, diff --git a/api/db/db_models.py b/api/db/db_models.py index 0c052ca18..0c4d12c03 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -934,6 +934,7 @@ class APIToken(DataBaseModel): token = CharField(max_length=255, null=False, index=True) dialog_id = CharField(max_length=32, null=False, index=True) source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True) + beta = CharField(max_length=255, null=True, index=True) class Meta: db_table = "api_token" @@ -1083,4 +1084,10 @@ def migrate_db(): ) except Exception: pass + try: + migrate( + migrator.add_column("api_token", "beta", CharField(max_length=255, null=True, index=True)) + ) + except Exception: + pass diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index 0fac2f248..e7aada1fd 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -13,8 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import json +from uuid import uuid4 +from agent.canvas import Canvas from api.db.db_models import DB, CanvasTemplate, UserCanvas +from api.db.services.api_service import API4ConversationService from api.db.services.common_service import CommonService +from api.db.services.conversation_service import structure_answer +from api.utils import get_uuid class CanvasTemplateService(CommonService): @@ -42,3 +48,115 @@ class UserCanvasService(CommonService): agents = agents.paginate(page_number, items_per_page) return list(agents.dicts()) + + +def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs): + e, cvs = UserCanvasService.get_by_id(agent_id) + assert e, "Agent not found." + assert cvs.user_id == tenant_id, "You do not own the agent." + + if not isinstance(cvs.dsl, str): + cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) + canvas = Canvas(cvs.dsl, tenant_id) + + if not session_id: + session_id = get_uuid() + conv = { + "id": session_id, + "dialog_id": cvs.id, + "user_id": kwargs.get("user_id", ""), + "message": [{"role": "assistant", "content": canvas.get_prologue()}], + "source": "agent", + "dsl": json.loads(cvs.dsl) + } + API4ConversationService.save(**conv) + yield "data:" + json.dumps({"code": 0, + "message": "", + "data": { + "session_id": session_id, + "answer": canvas.get_prologue(), + "reference": [], + "param": canvas.get_preset_param() + } + }, + ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" + return + else: + session_id = session_id + e, conv = API4ConversationService.get_by_id(session_id) + assert e, "Session not found!" 
+ canvas = Canvas(json.dumps(conv.dsl), tenant_id) + + messages = conv.message + question = { + "role": "user", + "content": question, + "id": str(uuid4()) + } + messages.append(question) + msg = [] + for m in messages: + if m["role"] == "system": + continue + if m["role"] == "assistant" and not msg: + continue + msg.append(m) + if not msg[-1].get("id"): + msg[-1]["id"] = get_uuid() + message_id = msg[-1]["id"] + + if not conv.reference: + conv.reference = [] + conv.message.append({"role": "assistant", "content": "", "id": message_id}) + conv.reference.append({"chunks": [], "doc_aggs": []}) + + final_ans = {"reference": [], "content": ""} + + canvas.add_user_input(msg[-1]["content"]) + + if stream: + try: + for ans in canvas.run(stream=stream): + if ans.get("running_status"): + yield "data:" + json.dumps({"code": 0, "message": "", + "data": {"answer": ans["content"], + "running_status": True}}, + ensure_ascii=False) + "\n\n" + continue + for k in ans.keys(): + final_ans[k] = ans[k] + ans = {"answer": ans["content"], "reference": ans.get("reference", [])} + ans = structure_answer(conv, ans, message_id, session_id) + yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, + ensure_ascii=False) + "\n\n" + + canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) + canvas.history.append(("assistant", final_ans["content"])) + if final_ans.get("reference"): + canvas.reference.append(final_ans["reference"]) + conv.dsl = json.loads(str(canvas)) + API4ConversationService.append_message(conv.id, conv.to_dict()) + except Exception as e: + conv.dsl = json.loads(str(canvas)) + API4ConversationService.append_message(conv.id, conv.to_dict()) + yield "data:" + json.dumps({"code": 500, "message": str(e), + "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, + ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" + + else: + for answer in 
canvas.run(stream=False): + if answer.get("running_status"): + continue + final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else "" + canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) + if final_ans.get("reference"): + canvas.reference.append(final_ans["reference"]) + conv.dsl = json.loads(str(canvas)) + + result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])} + result = structure_answer(conv, result, message_id, session_id) + API4ConversationService.append_message(conv.id, conv.to_dict()) + yield result + break \ No newline at end of file diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py new file mode 100644 index 000000000..7844bcc6a --- /dev/null +++ b/api/db/services/conversation_service.py @@ -0,0 +1,221 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from uuid import uuid4 +from api.db import StatusEnum +from api.db.db_models import Conversation, DB +from api.db.services.api_service import API4ConversationService +from api.db.services.common_service import CommonService +from api.db.services.dialog_service import DialogService, chat +from api.utils import get_uuid +import json +from copy import deepcopy + + +class ConversationService(CommonService): + model = Conversation + + @classmethod + @DB.connection_context() + def get_list(cls,dialog_id,page_number, items_per_page, orderby, desc, id , name): + sessions = cls.model.select().where(cls.model.dialog_id ==dialog_id) + if id: + sessions = sessions.where(cls.model.id == id) + if name: + sessions = sessions.where(cls.model.name == name) + if desc: + sessions = sessions.order_by(cls.model.getter_by(orderby).desc()) + else: + sessions = sessions.order_by(cls.model.getter_by(orderby).asc()) + + sessions = sessions.paginate(page_number, items_per_page) + + return list(sessions.dicts()) + + +def structure_answer(conv, ans, message_id, session_id): + reference = ans["reference"] + temp_reference = deepcopy(ans["reference"]) + if not conv.reference: + conv.reference.append(temp_reference) + else: + conv.reference[-1] = temp_reference + conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id} + + chunk_list = [{ + "id": chunk["chunk_id"], + "content": chunk["content"], + "document_id": chunk["doc_id"], + "document_name": chunk["docnm_kwd"], + "dataset_id": chunk["kb_id"], + "image_id": chunk["image_id"], + "similarity": chunk["similarity"], + "vector_similarity": chunk["vector_similarity"], + "term_similarity": chunk["term_similarity"], + "positions": chunk["positions"], + } for chunk in reference.get("chunks", [])] + + reference["chunks"] = chunk_list + ans["id"] = message_id + ans["session_id"] = session_id + + return ans + + +def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs): + 
assert name, "`name` can not be empty." + dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value) + assert dia, "You do not own the chat." + + if not session_id: + conv = { + "id": get_uuid(), + "dialog_id": chat_id, + "name": name, + "message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}] + } + ConversationService.save(**conv) + yield "data:" + json.dumps({"code": 0, "message": "", + "data": { + "answer": conv["message"][0]["content"], + "reference": {}, + "audio_binary": None, + "id": None, + "session_id": session_id + }}, + ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" + return + + conv = ConversationService.query(id=session_id, dialog_id=chat_id) + if not conv: + raise LookupError("Session does not exist") + + conv = conv[0] + msg = [] + question = { + "content": question, + "role": "user", + "id": str(uuid4()) + } + conv.message.append(question) + for m in conv.message: + if m["role"] == "system": + continue + if m["role"] == "assistant" and not msg: + continue + msg.append(m) + message_id = msg[-1].get("id") + e, dia = DialogService.get_by_id(conv.dialog_id) + + if not conv.reference: + conv.reference = [] + conv.message.append({"role": "assistant", "content": "", "id": message_id}) + conv.reference.append({"chunks": [], "doc_aggs": []}) + + if stream: + try: + for ans in chat(dia, msg, True, **kwargs): + ans = structure_answer(conv, ans, message_id, session_id) + yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n" + ConversationService.update_by_id(conv.id, conv.to_dict()) + except Exception as e: + yield "data:" + json.dumps({"code": 500, "message": str(e), + "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, + ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n" + + else: + answer = None + for ans in chat(dia, 
msg, False, **kwargs): + answer = structure_answer(conv, ans, message_id, session_id) + ConversationService.update_by_id(conv.id, conv.to_dict()) + break + yield answer + + +def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs): + e, dia = DialogService.get_by_id(dialog_id) + assert e, "Dialog not found" + if not session_id: + session_id = get_uuid() + conv = { + "id": session_id, + "dialog_id": dialog_id, + "user_id": kwargs.get("user_id", ""), + "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}] + } + API4ConversationService.save(**conv) + yield "data:" + json.dumps({"code": 0, "message": "", + "data": { + "answer": conv["message"][0]["content"], + "reference": {}, + "audio_binary": None, + "id": None, + "session_id": session_id + }}, + ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" + return + else: + session_id = session_id + e, conv = API4ConversationService.get_by_id(session_id) + assert e, "Session not found!" 
+ + messages = conv.message + question = { + "role": "user", + "content": question, + "id": str(uuid4()) + } + messages.append(question) + + msg = [] + for m in messages: + if m["role"] == "system": + continue + if m["role"] == "assistant" and not msg: + continue + msg.append(m) + if not msg[-1].get("id"): + msg[-1]["id"] = get_uuid() + message_id = msg[-1]["id"] + + if not conv.reference: + conv.reference = [] + conv.message.append({"role": "assistant", "content": "", "id": message_id}) + conv.reference.append({"chunks": [], "doc_aggs": []}) + + if stream: + try: + for ans in chat(dia, msg, True, **kwargs): + ans = structure_answer(conv, ans, message_id, session_id) + yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, + ensure_ascii=False) + "\n\n" + API4ConversationService.append_message(conv.id, conv.to_dict()) + except Exception as e: + yield "data:" + json.dumps({"code": 500, "message": str(e), + "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, + ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" + + else: + answer = None + for ans in chat(dia, msg, False, **kwargs): + answer = structure_answer(conv, ans, message_id, session_id) + API4ConversationService.append_message(conv.id, conv.to_dict()) + break + yield answer + diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 1a63b7962..6fd490187 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -23,7 +23,7 @@ from timeit import default_timer as timer import datetime from datetime import timedelta from api.db import LLMType, ParserType,StatusEnum -from api.db.db_models import Dialog, Conversation,DB +from api.db.db_models import Dialog, DB from api.db.services.common_service import CommonService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMService, TenantLLMService, 
LLMBundle @@ -60,27 +60,6 @@ class DialogService(CommonService): return list(chats.dicts()) -class ConversationService(CommonService): - model = Conversation - - @classmethod - @DB.connection_context() - def get_list(cls,dialog_id,page_number, items_per_page, orderby, desc, id , name): - sessions = cls.model.select().where(cls.model.dialog_id ==dialog_id) - if id: - sessions = sessions.where(cls.model.id == id) - if name: - sessions = sessions.where(cls.model.name == name) - if desc: - sessions = sessions.order_by(cls.model.getter_by(orderby).desc()) - else: - sessions = sessions.order_by(cls.model.getter_by(orderby).asc()) - - sessions = sessions.paginate(page_number, items_per_page) - - return list(sessions.dicts()) - - def message_fit_in(msg, max_length=4000): def count(): nonlocal msg diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index aea4931eb..f4b6f6874 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -425,11 +425,12 @@ def queue_raptor_tasks(doc): def doc_upload_and_parse(conversation_id, file_objs, user_id): from rag.app import presentation, picture, naive, audio, email - from api.db.services.dialog_service import ConversationService, DialogService + from api.db.services.dialog_service import DialogService from api.db.services.file_service import FileService from api.db.services.llm_service import LLMBundle from api.db.services.user_service import TenantService from api.db.services.api_service import API4ConversationService + from api.db.services.conversation_service import ConversationService e, conv = ConversationService.get_by_id(conversation_id) if not e: