Fix the bug that the agent could not find the context (#3795)

### What problem does this PR solve?

Fix the bug that the agent could not find the context
#3682
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
This commit is contained in:
liuhua 2024-12-02 19:05:18 +08:00 committed by GitHub
parent 8b650fc9ef
commit 9654e64a0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 29 deletions

View File

@ -35,7 +35,7 @@ from api.db.services.llm_service import LLMBundle
@manager.route('/chats/<chat_id>/sessions', methods=['POST']) @manager.route('/chats/<chat_id>/sessions', methods=['POST'])
@token_required @token_required
def create(tenant_id, chat_id): def create(tenant_id,chat_id):
req = request.json req = request.json
req["dialog_id"] = chat_id req["dialog_id"] = chat_id
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
@ -77,9 +77,10 @@ def create_agent_session(tenant_id, agent_id):
conv = { conv = {
"id": get_uuid(), "id": get_uuid(),
"dialog_id": cvs.id, "dialog_id": cvs.id,
"user_id": req.get("usr_id", "") if isinstance(req, dict) else "", "user_id": req.get("usr_id","") if isinstance(req, dict) else "",
"message": [{"role": "assistant", "content": canvas.get_prologue()}], "message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent" "source": "agent",
"dsl":json.loads(cvs.dsl)
} }
API4ConversationService.save(**conv) API4ConversationService.save(**conv)
conv["agent_id"] = conv.pop("dialog_id") conv["agent_id"] = conv.pop("dialog_id")
@ -88,11 +89,11 @@ def create_agent_session(tenant_id, agent_id):
@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT']) @manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])
@token_required @token_required
def update(tenant_id, chat_id, session_id): def update(tenant_id,chat_id,session_id):
req = request.json req = request.json
req["dialog_id"] = chat_id req["dialog_id"] = chat_id
conv_id = session_id conv_id = session_id
conv = ConversationService.query(id=conv_id, dialog_id=chat_id) conv = ConversationService.query(id=conv_id,dialog_id=chat_id)
if not conv: if not conv:
return get_error_data_result(message="Session does not exist") return get_error_data_result(message="Session does not exist")
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
@ -123,12 +124,12 @@ def completion(tenant_id, chat_id):
return get_error_data_result(message="`name` can not be empty.") return get_error_data_result(message="`name` can not be empty.")
ConversationService.save(**conv) ConversationService.save(**conv)
e, conv = ConversationService.get_by_id(conv["id"]) e, conv = ConversationService.get_by_id(conv["id"])
session_id = conv.id session_id=conv.id
else: else:
session_id = req.get("session_id") session_id = req.get("session_id")
if not req.get("question"): if not req.get("question"):
return get_error_data_result(message="Please input your question.") return get_error_data_result(message="Please input your question.")
conv = ConversationService.query(id=session_id, dialog_id=chat_id) conv = ConversationService.query(id=session_id,dialog_id=chat_id)
if not conv: if not conv:
return get_error_data_result(message="Session does not exist") return get_error_data_result(message="Session does not exist")
conv = conv[0] conv = conv[0]
@ -182,18 +183,18 @@ def completion(tenant_id, chat_id):
chunk_list.append(new_chunk) chunk_list.append(new_chunk)
reference["chunks"] = chunk_list reference["chunks"] = chunk_list
ans["id"] = message_id ans["id"] = message_id
ans["session_id"] = session_id ans["session_id"]=session_id
def stream(): def stream():
nonlocal dia, msg, req, conv nonlocal dia, msg, req, conv
try: try:
for ans in chat(dia, msg, **req): for ans in chat(dia, msg, **req):
fillin_conv(ans) fillin_conv(ans)
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict()) ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e: except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}}, "data": {"answer": "**ERROR**: " + str(e),"reference": []}},
ensure_ascii=False) + "\n\n" ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
@ -237,7 +238,8 @@ def agent_completion(tenant_id, agent_id):
"dialog_id": cvs.id, "dialog_id": cvs.id,
"user_id": req.get("user_id", ""), "user_id": req.get("user_id", ""),
"message": [{"role": "assistant", "content": canvas.get_prologue()}], "message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent" "source": "agent",
"dsl": json.loads(cvs.dsl)
} }
API4ConversationService.save(**conv) API4ConversationService.save(**conv)
conv = API4Conversation(**conv) conv = API4Conversation(**conv)
@ -246,6 +248,7 @@ def agent_completion(tenant_id, agent_id):
e, conv = API4ConversationService.get_by_id(req["session_id"]) e, conv = API4ConversationService.get_by_id(req["session_id"])
if not e: if not e:
return get_error_data_result(message="Session not found!") return get_error_data_result(message="Session not found!")
canvas = Canvas(json.dumps(conv.dsl), tenant_id)
messages = conv.message messages = conv.message
question = req.get("question") question = req.get("question")
@ -267,11 +270,11 @@ def agent_completion(tenant_id, agent_id):
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid() if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"] message_id = msg[-1]["id"]
if "quote" not in req: req["quote"] = False
stream = req.get("stream", True) stream = req.get("stream", True)
def fillin_conv(ans): def fillin_conv(ans):
reference = ans["reference"] reference = ans["reference"]
print(reference,flush=True)
temp_reference = deepcopy(ans["reference"]) temp_reference = deepcopy(ans["reference"])
nonlocal conv, message_id nonlocal conv, message_id
if not conv.reference: if not conv.reference:
@ -322,7 +325,7 @@ def agent_completion(tenant_id, agent_id):
def sse(): def sse():
nonlocal answer, cvs nonlocal answer, cvs
try: try:
for ans in canvas.run(stream=True): for ans in canvas.run(stream=stream):
if ans.get("running_status"): if ans.get("running_status"):
yield "data:" + json.dumps({"code": 0, "message": "", yield "data:" + json.dumps({"code": 0, "message": "",
"data": {"answer": ans["content"], "data": {"answer": ans["content"],
@ -341,10 +344,10 @@ def agent_completion(tenant_id, agent_id):
canvas.history.append(("assistant", final_ans["content"])) canvas.history.append(("assistant", final_ans["content"]))
if final_ans.get("reference"): if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"]) canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas)) conv.dsl = json.loads(str(canvas))
API4ConversationService.append_message(conv.id, conv.to_dict()) API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e: except Exception as e:
cvs.dsl = json.loads(str(canvas)) conv.dsl = json.loads(str(canvas))
API4ConversationService.append_message(conv.id, conv.to_dict()) API4ConversationService.append_message(conv.id, conv.to_dict())
yield "data:" + json.dumps({"code": 500, "message": str(e), yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}}, "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
@ -364,7 +367,7 @@ def agent_completion(tenant_id, agent_id):
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"): if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"]) canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas)) conv.dsl = json.loads(str(canvas))
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])} result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
fillin_conv(result) fillin_conv(result)
@ -372,10 +375,9 @@ def agent_completion(tenant_id, agent_id):
rename_field(result) rename_field(result)
return get_result(data=result) return get_result(data=result)
@manager.route('/chats/<chat_id>/sessions', methods=['GET']) @manager.route('/chats/<chat_id>/sessions', methods=['GET'])
@token_required @token_required
def list_session(chat_id, tenant_id): def list_session(chat_id,tenant_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message=f"You don't own the assistant {chat_id}.") return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
id = request.args.get("id") id = request.args.get("id")
@ -387,7 +389,7 @@ def list_session(chat_id, tenant_id):
desc = False desc = False
else: else:
desc = True desc = True
convs = ConversationService.get_list(chat_id, page_number, items_per_page, orderby, desc, id, name) convs = ConversationService.get_list(chat_id,page_number,items_per_page,orderby,desc,id,name)
if not convs: if not convs:
return get_result(data=[]) return get_result(data=[])
for conv in convs: for conv in convs:
@ -429,7 +431,7 @@ def list_session(chat_id, tenant_id):
@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"]) @manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])
@token_required @token_required
def delete(tenant_id, chat_id): def delete(tenant_id,chat_id):
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_data_result(message="You don't own the chat") return get_error_data_result(message="You don't own the chat")
req = request.json req = request.json
@ -437,22 +439,21 @@ def delete(tenant_id, chat_id):
if not req: if not req:
ids = None ids = None
else: else:
ids = req.get("ids") ids=req.get("ids")
if not ids: if not ids:
conv_list = [] conv_list = []
for conv in convs: for conv in convs:
conv_list.append(conv.id) conv_list.append(conv.id)
else: else:
conv_list = ids conv_list=ids
for id in conv_list: for id in conv_list:
conv = ConversationService.query(id=id, dialog_id=chat_id) conv = ConversationService.query(id=id,dialog_id=chat_id)
if not conv: if not conv:
return get_error_data_result(message="The chat doesn't own the session") return get_error_data_result(message="The chat doesn't own the session")
ConversationService.delete_by_id(id) ConversationService.delete_by_id(id)
return get_result() return get_result()
@manager.route('/sessions/ask', methods=['POST']) @manager.route('/sessions/ask', methods=['POST'])
@token_required @token_required
def ask_about(tenant_id): def ask_about(tenant_id):
@ -461,18 +462,17 @@ def ask_about(tenant_id):
return get_error_data_result("`question` is required.") return get_error_data_result("`question` is required.")
if not req.get("dataset_ids"): if not req.get("dataset_ids"):
return get_error_data_result("`dataset_ids` is required.") return get_error_data_result("`dataset_ids` is required.")
if not isinstance(req.get("dataset_ids"), list): if not isinstance(req.get("dataset_ids"),list):
return get_error_data_result("`dataset_ids` should be a list.") return get_error_data_result("`dataset_ids` should be a list.")
req["kb_ids"] = req.pop("dataset_ids") req["kb_ids"]=req.pop("dataset_ids")
for kb_id in req["kb_ids"]: for kb_id in req["kb_ids"]:
if not KnowledgebaseService.accessible(kb_id, tenant_id): if not KnowledgebaseService.accessible(kb_id,tenant_id):
return get_error_data_result(f"You don't own the dataset {kb_id}.") return get_error_data_result(f"You don't own the dataset {kb_id}.")
kbs = KnowledgebaseService.query(id=kb_id) kbs = KnowledgebaseService.query(id=kb_id)
kb = kbs[0] kb = kbs[0]
if kb.chunk_num == 0: if kb.chunk_num == 0:
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
uid = tenant_id uid = tenant_id
def stream(): def stream():
nonlocal req, uid nonlocal req, uid
try: try:

View File

@ -947,7 +947,7 @@ class API4Conversation(DataBaseModel):
reference = JSONField(null=True, default=[]) reference = JSONField(null=True, default=[])
tokens = IntegerField(default=0) tokens = IntegerField(default=0)
source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True) source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
dsl = JSONField(null=True, default={})
duration = FloatField(default=0, index=True) duration = FloatField(default=0, index=True)
round = IntegerField(default=0, index=True) round = IntegerField(default=0, index=True)
thumb_up = IntegerField(default=0, index=True) thumb_up = IntegerField(default=0, index=True)
@ -1070,3 +1070,10 @@ def migrate_db():
) )
except Exception: except Exception:
pass pass
try:
migrate(
migrator.add_column("api_4_conversation","dsl",JSONField(null=True, default={}))
)
except Exception:
pass