add agent completion API (#3192)

### What problem does this PR solve?

#3105

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu 2024-11-04 17:20:16 +08:00 committed by GitHub
parent 57f23e0808
commit 8305632852
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 160 additions and 2 deletions

View File

@ -14,16 +14,22 @@
# limitations under the License. # limitations under the License.
# #
import json import json
from functools import partial
from uuid import uuid4 from uuid import uuid4
from flask import request, Response from flask import request, Response
from agent.canvas import Canvas
from api.db import StatusEnum from api.db import StatusEnum
from api.db.db_models import API4Conversation
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.dialog_service import DialogService, ConversationService, chat from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result from api.utils.api_utils import get_error_data_result
from api.utils.api_utils import get_result, token_required from api.utils.api_utils import get_result, token_required
@manager.route('/chats/<chat_id>/sessions', methods=['POST']) @manager.route('/chats/<chat_id>/sessions', methods=['POST'])
@token_required @token_required
def create(tenant_id,chat_id): def create(tenant_id,chat_id):
@ -31,7 +37,7 @@ def create(tenant_id,chat_id):
req["dialog_id"] = chat_id req["dialog_id"] = chat_id
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
if not dia: if not dia:
return get_error_data_result(retmsg="You do not own the assistant") return get_error_data_result(retmsg="You do not own the assistant.")
conv = { conv = {
"id": get_uuid(), "id": get_uuid(),
"dialog_id": req["dialog_id"], "dialog_id": req["dialog_id"],
@ -50,6 +56,32 @@ def create(tenant_id,chat_id):
del conv["reference"] del conv["reference"]
return get_result(data=conv) return get_result(data=conv)
@manager.route('/agents/<agent_id>/sessions', methods=['POST'])
@token_required
def create_agent_session(tenant_id, agent_id):
req = request.json
e, cvs = UserCanvasService.get_by_id(agent_id)
if not e:
return get_error_data_result("Agent not found.")
if cvs.user_id != tenant_id:
return get_error_data_result(retmsg="You do not own the agent.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
canvas = Canvas(cvs.dsl, tenant_id)
conv = {
"id": get_uuid(),
"dialog_id": cvs.id,
"user_id": req.get("user_id", ""),
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent"
}
API4ConversationService.save(**conv)
return get_result(data=conv)
@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT']) @manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])
@token_required @token_required
def update(tenant_id,chat_id,session_id): def update(tenant_id,chat_id,session_id):
@ -74,7 +106,7 @@ def update(tenant_id,chat_id,session_id):
@manager.route('/chats/<chat_id>/completions', methods=['POST']) @manager.route('/chats/<chat_id>/completions', methods=['POST'])
@token_required @token_required
def completion(tenant_id,chat_id): def completion(tenant_id, chat_id):
req = request.json req = request.json
if not req.get("session_id"): if not req.get("session_id"):
conv = { conv = {
@ -158,6 +190,130 @@ def completion(tenant_id,chat_id):
break break
return get_result(data=answer) return get_result(data=answer)
@manager.route('/agents/<agent_id>/completions', methods=['POST'])
@token_required
def agent_completion(tenant_id, agent_id):
req = request.json
e, cvs = UserCanvasService.get_by_id(agent_id)
if not e:
return get_error_data_result("Agent not found.")
if cvs.user_id != tenant_id:
return get_error_data_result(retmsg="You do not own the agent.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
canvas = Canvas(cvs.dsl, tenant_id)
msg = []
for m in req["messages"]:
if m["role"] == "system":
continue
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]
if not req.get("session_id"):
session_id = get_uuid()
conv = {
"id": session_id,
"dialog_id": cvs.id,
"user_id": req.get("user_id", ""),
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent"
}
API4ConversationService.save(**conv)
conv = API4Conversation(**conv)
else:
session_id = req.get("session_id")
e, conv = API4ConversationService.get_by_id(req["session_id"])
if not e:
return get_error_data_result(retmsg="Session not found!")
if "quote" not in req: req["quote"] = False
stream = req.get("stream", True)
def fillin_conv(ans):
nonlocal conv, message_id
if not conv.reference:
conv.reference.append(ans["reference"])
else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
ans["id"] = message_id
ans["session_id"] = session_id
def rename_field(ans):
reference = ans['reference']
if not isinstance(reference, dict):
return
for chunk_i in reference.get('chunks', []):
if 'docnm_kwd' in chunk_i:
chunk_i['doc_name'] = chunk_i['docnm_kwd']
chunk_i.pop('docnm_kwd')
conv.message.append(msg[-1])
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
final_ans = {"reference": [], "content": ""}
canvas.messages.append(msg[-1])
canvas.add_user_input(msg[-1]["content"])
answer = canvas.run(stream=stream)
assert answer is not None, "Nothing. Is it over?"
if stream:
assert isinstance(answer, partial), "Nothing. Is it over?"
def sse():
nonlocal answer, cvs, conv
try:
for ans in answer():
for k in ans.keys():
final_ans[k] = ans[k]
ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
fillin_conv(ans)
rename_field(ans)
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans},
ensure_ascii=False) + "\n\n"
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(sse(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
fillin_conv(result)
API4ConversationService.append_message(conv.id, conv.to_dict())
rename_field(result)
return get_result(data=result)
@manager.route('/chats/<chat_id>/sessions', methods=['GET']) @manager.route('/chats/<chat_id>/sessions', methods=['GET'])
@token_required @token_required
def list(chat_id,tenant_id): def list(chat_id,tenant_id):
@ -211,6 +367,7 @@ def list(chat_id,tenant_id):
del conv["reference"] del conv["reference"]
return get_result(data=convs) return get_result(data=convs)
@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"]) @manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])
@token_required @token_required
def delete(tenant_id,chat_id): def delete(tenant_id,chat_id):

View File

@ -22,5 +22,6 @@ from api.db.services.common_service import CommonService
class CanvasTemplateService(CommonService): class CanvasTemplateService(CommonService):
model = CanvasTemplate model = CanvasTemplate
class UserCanvasService(CommonService): class UserCanvasService(CommonService):
model = UserCanvas model = UserCanvas