mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-13 14:39:02 +08:00
Fix the bug that the agent could not find the context (#3795)
### What problem does this PR solve? Fix the bug that the agent could not find the context #3682 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
This commit is contained in:
parent
8b650fc9ef
commit
9654e64a0a
@ -79,7 +79,8 @@ def create_agent_session(tenant_id, agent_id):
|
|||||||
"dialog_id": cvs.id,
|
"dialog_id": cvs.id,
|
||||||
"user_id": req.get("usr_id","") if isinstance(req, dict) else "",
|
"user_id": req.get("usr_id","") if isinstance(req, dict) else "",
|
||||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
||||||
"source": "agent"
|
"source": "agent",
|
||||||
|
"dsl":json.loads(cvs.dsl)
|
||||||
}
|
}
|
||||||
API4ConversationService.save(**conv)
|
API4ConversationService.save(**conv)
|
||||||
conv["agent_id"] = conv.pop("dialog_id")
|
conv["agent_id"] = conv.pop("dialog_id")
|
||||||
@ -237,7 +238,8 @@ def agent_completion(tenant_id, agent_id):
|
|||||||
"dialog_id": cvs.id,
|
"dialog_id": cvs.id,
|
||||||
"user_id": req.get("user_id", ""),
|
"user_id": req.get("user_id", ""),
|
||||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
||||||
"source": "agent"
|
"source": "agent",
|
||||||
|
"dsl": json.loads(cvs.dsl)
|
||||||
}
|
}
|
||||||
API4ConversationService.save(**conv)
|
API4ConversationService.save(**conv)
|
||||||
conv = API4Conversation(**conv)
|
conv = API4Conversation(**conv)
|
||||||
@ -246,6 +248,7 @@ def agent_completion(tenant_id, agent_id):
|
|||||||
e, conv = API4ConversationService.get_by_id(req["session_id"])
|
e, conv = API4ConversationService.get_by_id(req["session_id"])
|
||||||
if not e:
|
if not e:
|
||||||
return get_error_data_result(message="Session not found!")
|
return get_error_data_result(message="Session not found!")
|
||||||
|
canvas = Canvas(json.dumps(conv.dsl), tenant_id)
|
||||||
|
|
||||||
messages = conv.message
|
messages = conv.message
|
||||||
question = req.get("question")
|
question = req.get("question")
|
||||||
@ -267,11 +270,11 @@ def agent_completion(tenant_id, agent_id):
|
|||||||
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
|
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
|
||||||
message_id = msg[-1]["id"]
|
message_id = msg[-1]["id"]
|
||||||
|
|
||||||
if "quote" not in req: req["quote"] = False
|
|
||||||
stream = req.get("stream", True)
|
stream = req.get("stream", True)
|
||||||
|
|
||||||
def fillin_conv(ans):
|
def fillin_conv(ans):
|
||||||
reference = ans["reference"]
|
reference = ans["reference"]
|
||||||
|
print(reference,flush=True)
|
||||||
temp_reference = deepcopy(ans["reference"])
|
temp_reference = deepcopy(ans["reference"])
|
||||||
nonlocal conv, message_id
|
nonlocal conv, message_id
|
||||||
if not conv.reference:
|
if not conv.reference:
|
||||||
@ -322,7 +325,7 @@ def agent_completion(tenant_id, agent_id):
|
|||||||
def sse():
|
def sse():
|
||||||
nonlocal answer, cvs
|
nonlocal answer, cvs
|
||||||
try:
|
try:
|
||||||
for ans in canvas.run(stream=True):
|
for ans in canvas.run(stream=stream):
|
||||||
if ans.get("running_status"):
|
if ans.get("running_status"):
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "",
|
yield "data:" + json.dumps({"code": 0, "message": "",
|
||||||
"data": {"answer": ans["content"],
|
"data": {"answer": ans["content"],
|
||||||
@ -341,10 +344,10 @@ def agent_completion(tenant_id, agent_id):
|
|||||||
canvas.history.append(("assistant", final_ans["content"]))
|
canvas.history.append(("assistant", final_ans["content"]))
|
||||||
if final_ans.get("reference"):
|
if final_ans.get("reference"):
|
||||||
canvas.reference.append(final_ans["reference"])
|
canvas.reference.append(final_ans["reference"])
|
||||||
cvs.dsl = json.loads(str(canvas))
|
conv.dsl = json.loads(str(canvas))
|
||||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
cvs.dsl = json.loads(str(canvas))
|
conv.dsl = json.loads(str(canvas))
|
||||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||||
@ -364,7 +367,7 @@ def agent_completion(tenant_id, agent_id):
|
|||||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||||
if final_ans.get("reference"):
|
if final_ans.get("reference"):
|
||||||
canvas.reference.append(final_ans["reference"])
|
canvas.reference.append(final_ans["reference"])
|
||||||
cvs.dsl = json.loads(str(canvas))
|
conv.dsl = json.loads(str(canvas))
|
||||||
|
|
||||||
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
||||||
fillin_conv(result)
|
fillin_conv(result)
|
||||||
@ -372,7 +375,6 @@ def agent_completion(tenant_id, agent_id):
|
|||||||
rename_field(result)
|
rename_field(result)
|
||||||
return get_result(data=result)
|
return get_result(data=result)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/chats/<chat_id>/sessions', methods=['GET'])
|
@manager.route('/chats/<chat_id>/sessions', methods=['GET'])
|
||||||
@token_required
|
@token_required
|
||||||
def list_session(chat_id,tenant_id):
|
def list_session(chat_id,tenant_id):
|
||||||
@ -452,7 +454,6 @@ def delete(tenant_id, chat_id):
|
|||||||
ConversationService.delete_by_id(id)
|
ConversationService.delete_by_id(id)
|
||||||
return get_result()
|
return get_result()
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/sessions/ask', methods=['POST'])
|
@manager.route('/sessions/ask', methods=['POST'])
|
||||||
@token_required
|
@token_required
|
||||||
def ask_about(tenant_id):
|
def ask_about(tenant_id):
|
||||||
@ -472,7 +473,6 @@ def ask_about(tenant_id):
|
|||||||
if kb.chunk_num == 0:
|
if kb.chunk_num == 0:
|
||||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||||
uid = tenant_id
|
uid = tenant_id
|
||||||
|
|
||||||
def stream():
|
def stream():
|
||||||
nonlocal req, uid
|
nonlocal req, uid
|
||||||
try:
|
try:
|
||||||
|
@ -947,7 +947,7 @@ class API4Conversation(DataBaseModel):
|
|||||||
reference = JSONField(null=True, default=[])
|
reference = JSONField(null=True, default=[])
|
||||||
tokens = IntegerField(default=0)
|
tokens = IntegerField(default=0)
|
||||||
source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
|
source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
|
||||||
|
dsl = JSONField(null=True, default={})
|
||||||
duration = FloatField(default=0, index=True)
|
duration = FloatField(default=0, index=True)
|
||||||
round = IntegerField(default=0, index=True)
|
round = IntegerField(default=0, index=True)
|
||||||
thumb_up = IntegerField(default=0, index=True)
|
thumb_up = IntegerField(default=0, index=True)
|
||||||
@ -1070,3 +1070,10 @@ def migrate_db():
|
|||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
migrate(
|
||||||
|
migrator.add_column("api_4_conversation","dsl",JSONField(null=True, default={}))
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user