From b35e811fe73aadd78bafb750c060b2a480601836 Mon Sep 17 00:00:00 2001 From: liuhua <10215101452@stu.ecnu.edu.cn> Date: Thu, 19 Dec 2024 17:24:26 +0800 Subject: [PATCH] Add parameters for ask_chat and fix bugs in list_sessions (#4119) ### What problem does this PR solve? Add parameters for ask_chat and fix bugs in list_sessions #4105 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn> --- api/apps/sdk/session.py | 43 +++++++++++------------ sdk/python/ragflow_sdk/modules/session.py | 10 +++--- 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index e074462b9..0e50eeddb 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -65,20 +65,24 @@ def create(tenant_id, chat_id): @manager.route('/agents//sessions', methods=['POST']) # noqa: F821 @token_required def create_agent_session(tenant_id, agent_id): + req = request.json e, cvs = UserCanvasService.get_by_id(agent_id) if not e: return get_error_data_result("Agent not found.") + if not UserCanvasService.query(user_id=tenant_id,id=agent_id): + return get_error_data_result("You cannot access the agent.") + if not isinstance(cvs.dsl, str): cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) canvas = Canvas(cvs.dsl, tenant_id) if canvas.get_preset_param(): - return get_error_data_result("The agent can't create a session directly") + return get_error_data_result("The agent cannot create a session directly") conv = { "id": get_uuid(), "dialog_id": cvs.id, - "user_id": tenant_id, + "user_id": req.get("usr_id","") if isinstance(req, dict) else "", "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": json.loads(cvs.dsl) @@ -199,17 +203,15 @@ def list_session(tenant_id, chat_id): chunks = conv["reference"][chunk_num]["chunks"] for chunk in chunks: new_chunk = { - "id": chunk["chunk_id"], - "content": chunk["content_with_weight"], - "document_id": chunk["doc_id"], - "document_name": chunk["docnm_kwd"], - "dataset_id": chunk["kb_id"], - "image_id": chunk.get("image_id", ""), - "similarity": chunk["similarity"], - "vector_similarity": chunk["vector_similarity"], - "term_similarity": chunk["term_similarity"], - "positions": chunk["positions"], + "id": chunk.get("chunk_id", chunk.get("id")), + "content": chunk.get("content_with_weight", chunk.get("content")), + "document_id": chunk.get("doc_id", chunk.get("document_id")), + "document_name": chunk.get("docnm_kwd", chunk.get("document_name")), + "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")), + "image_id": chunk.get("image_id", chunk.get("img_id")), + "positions": chunk.get("positions", chunk.get("position_int")), } + chunk_list.append(new_chunk) chunk_num += 1 messages[message_num]["reference"] = chunk_list @@ -254,16 +256,13 @@ def list_agent_session(tenant_id, agent_id): chunks = conv["reference"][chunk_num]["chunks"] for chunk in chunks: new_chunk = { - "id": chunk["chunk_id"], - "content": chunk["content"], - "document_id": chunk["doc_id"], - "document_name": chunk["docnm_kwd"], - "dataset_id": chunk["kb_id"], - "image_id": chunk.get("image_id", ""), - "similarity": chunk["similarity"], - "vector_similarity": chunk["vector_similarity"], - "term_similarity": chunk["term_similarity"], - "positions": chunk["positions"], + "id": chunk.get("chunk_id", chunk.get("id")), + "content": chunk.get("content_with_weight", chunk.get("content")), + "document_id": chunk.get("doc_id", chunk.get("document_id")), + "document_name": chunk.get("docnm_kwd", chunk.get("document_name")), + "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")), + "image_id": chunk.get("image_id", chunk.get("img_id")), + "positions": chunk.get("positions", chunk.get("position_int")), } chunk_list.append(new_chunk) chunk_num += 1 diff --git a/sdk/python/ragflow_sdk/modules/session.py b/sdk/python/ragflow_sdk/modules/session.py index 19345bf51..15b1e5e3e 100644 --- a/sdk/python/ragflow_sdk/modules/session.py +++ b/sdk/python/ragflow_sdk/modules/session.py @@ -17,11 +17,11 @@ class Session(Base): self.__session_type = "agent" super().__init__(rag, res_dict) - def ask(self, question,stream=True): + def ask(self, question,stream=True,**kwargs): if self.__session_type == "agent": res=self._ask_agent(question,stream) elif self.__session_type == "chat": - res=self._ask_chat(question,stream) + res=self._ask_chat(question,stream,**kwargs) for line in res.iter_lines(): line = line.decode("utf-8") if line.startswith("{"): @@ -45,9 +45,11 @@ class Session(Base): yield message - def _ask_chat(self, question: str, stream: bool): + def _ask_chat(self, question: str, stream: bool,**kwargs): + json_data={"question": question, "stream": True,"session_id":self.id} + json_data.update(kwargs) res = self.post(f"/chats/{self.chat_id}/completions", - {"question": question, "stream": True,"session_id":self.id}, stream=stream) + json_data, stream=stream) return res def _ask_agent(self,question:str,stream:bool): res = self.post(f"/agents/{self.agent_id}/completions",