From 8305632852a44648b923acd1e5a85627040519a4 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Mon, 4 Nov 2024 17:20:16 +0800 Subject: [PATCH] add agent completion API (#3192) ### What problem does this PR solve? #3105 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/sdk/session.py | 161 +++++++++++++++++++++++++++++- api/db/services/canvas_service.py | 1 + 2 files changed, 160 insertions(+), 2 deletions(-) diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 1f4434525..1b7b0ce0c 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -14,16 +14,22 @@ # limitations under the License. # import json +from functools import partial from uuid import uuid4 from flask import request, Response +from agent.canvas import Canvas from api.db import StatusEnum +from api.db.db_models import API4Conversation +from api.db.services.api_service import API4ConversationService +from api.db.services.canvas_service import UserCanvasService from api.db.services.dialog_service import DialogService, ConversationService, chat from api.utils import get_uuid from api.utils.api_utils import get_error_data_result from api.utils.api_utils import get_result, token_required + @manager.route('/chats//sessions', methods=['POST']) @token_required def create(tenant_id,chat_id): @@ -31,7 +37,7 @@ def create(tenant_id,chat_id): req["dialog_id"] = chat_id dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) if not dia: - return get_error_data_result(retmsg="You do not own the assistant") + return get_error_data_result(retmsg="You do not own the assistant.") conv = { "id": get_uuid(), "dialog_id": req["dialog_id"], @@ -50,6 +56,32 @@ def create(tenant_id,chat_id): del conv["reference"] return get_result(data=conv) + +@manager.route('/agents//sessions', methods=['POST']) +@token_required +def create_agent_session(tenant_id, agent_id): + req = request.json + e, cvs = UserCanvasService.get_by_id(agent_id) + if not e: + return get_error_data_result("Agent not found.") + if cvs.user_id != tenant_id: + return get_error_data_result(retmsg="You do not own the agent.") + + if not isinstance(cvs.dsl, str): + cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) + + canvas = Canvas(cvs.dsl, tenant_id) + conv = { + "id": get_uuid(), + "dialog_id": cvs.id, + "user_id": req.get("user_id", ""), + "message": [{"role": "assistant", "content": canvas.get_prologue()}], + "source": "agent" + } + API4ConversationService.save(**conv) + return get_result(data=conv) + + @manager.route('/chats//sessions/', methods=['PUT']) @token_required def update(tenant_id,chat_id,session_id): @@ -74,7 +106,7 @@ def update(tenant_id,chat_id,session_id): @manager.route('/chats//completions', methods=['POST']) @token_required -def completion(tenant_id,chat_id): +def completion(tenant_id, chat_id): req = request.json if not req.get("session_id"): conv = { @@ -158,6 +190,130 @@ def completion(tenant_id,chat_id): break return get_result(data=answer) + +@manager.route('/agents//completions', methods=['POST']) +@token_required +def agent_completion(tenant_id, agent_id): + req = request.json + e, cvs = UserCanvasService.get_by_id(agent_id) + if not e: + return get_error_data_result("Agent not found.") + if cvs.user_id != tenant_id: + return get_error_data_result(retmsg="You do not own the agent.") + if not isinstance(cvs.dsl, str): + cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) + + canvas = Canvas(cvs.dsl, tenant_id) + + msg = [] + for m in req["messages"]: + if m["role"] == "system": + continue + if m["role"] == "assistant" and not msg: + continue + msg.append(m) + if not msg[-1].get("id"): msg[-1]["id"] = get_uuid() + message_id = msg[-1]["id"] + + if not req.get("session_id"): + session_id = get_uuid() + conv = { + "id": session_id, + "dialog_id": cvs.id, + "user_id": req.get("user_id", ""), + "message": [{"role": "assistant", "content": canvas.get_prologue()}], + "source": "agent" + } + API4ConversationService.save(**conv) + conv = API4Conversation(**conv) + else: + session_id = req.get("session_id") + e, conv = API4ConversationService.get_by_id(req["session_id"]) + if not e: + return get_error_data_result(retmsg="Session not found!") + + if "quote" not in req: req["quote"] = False + stream = req.get("stream", True) + + def fillin_conv(ans): + nonlocal conv, message_id + if not conv.reference: + conv.reference.append(ans["reference"]) + else: + conv.reference[-1] = ans["reference"] + conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id} + ans["id"] = message_id + ans["session_id"] = session_id + + def rename_field(ans): + reference = ans['reference'] + if not isinstance(reference, dict): + return + for chunk_i in reference.get('chunks', []): + if 'docnm_kwd' in chunk_i: + chunk_i['doc_name'] = chunk_i['docnm_kwd'] + chunk_i.pop('docnm_kwd') + conv.message.append(msg[-1]) + + if not conv.reference: + conv.reference = [] + conv.message.append({"role": "assistant", "content": "", "id": message_id}) + conv.reference.append({"chunks": [], "doc_aggs": []}) + + final_ans = {"reference": [], "content": ""} + + canvas.messages.append(msg[-1]) + canvas.add_user_input(msg[-1]["content"]) + answer = canvas.run(stream=stream) + + assert answer is not None, "Nothing. Is it over?" + + if stream: + assert isinstance(answer, partial), "Nothing. Is it over?" + + def sse(): + nonlocal answer, cvs, conv + try: + for ans in answer(): + for k in ans.keys(): + final_ans[k] = ans[k] + ans = {"answer": ans["content"], "reference": ans.get("reference", [])} + fillin_conv(ans) + rename_field(ans) + yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, + ensure_ascii=False) + "\n\n" + + canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) + if final_ans.get("reference"): + canvas.reference.append(final_ans["reference"]) + cvs.dsl = json.loads(str(canvas)) + API4ConversationService.append_message(conv.id, conv.to_dict()) + except Exception as e: + yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), + "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, + ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n" + + resp = Response(sse(), mimetype="text/event-stream") + resp.headers.add_header("Cache-control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + return resp + + final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else "" + canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) + if final_ans.get("reference"): + canvas.reference.append(final_ans["reference"]) + cvs.dsl = json.loads(str(canvas)) + + result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])} + fillin_conv(result) + API4ConversationService.append_message(conv.id, conv.to_dict()) + rename_field(result) + return get_result(data=result) + + @manager.route('/chats//sessions', methods=['GET']) @token_required def list(chat_id,tenant_id): @@ -211,6 +367,7 @@ def list(chat_id,tenant_id): del conv["reference"] return get_result(data=convs) + @manager.route('/chats//sessions', methods=["DELETE"]) @token_required def delete(tenant_id,chat_id): diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index ed2cdf63a..112dcf0bf 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -22,5 +22,6 @@ from api.db.services.common_service import CommonService class CanvasTemplateService(CommonService): model = CanvasTemplate + class UserCanvasService(CommonService): model = UserCanvas