From 878dca26bb9ff2d435f8f0207399819208ffc0fa Mon Sep 17 00:00:00 2001
From: LiuHua <10215101452@stu.ecnu.edu.cn>
Date: Thu, 5 Sep 2024 15:08:02 +0800
Subject: [PATCH] SDK for Assistant (#2266)

### What problem does this PR solve?

SDK for Assistant
#1102

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn>
---
 api/apps/sdk/assistant.py                    | 307 +++++++++++++++++++
 sdk/python/ragflow/__init__.py               |   3 +-
 sdk/python/ragflow/modules/chat_assistant.py |  63 ++++
 sdk/python/ragflow/ragflow.py                |  71 ++++
 sdk/python/test/common.py                    |   2 +-
 sdk/python/test/t_assistant.py               |  66 +++++
 6 files changed, 510 insertions(+), 2 deletions(-)
 create mode 100644 api/apps/sdk/assistant.py
 create mode 100644 sdk/python/ragflow/modules/chat_assistant.py
 create mode 100644 sdk/python/test/t_assistant.py

diff --git a/api/apps/sdk/assistant.py b/api/apps/sdk/assistant.py
new file mode 100644
index 000000000..c71b1e670
--- /dev/null
+++ b/api/apps/sdk/assistant.py
@@ -0,0 +1,307 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+from flask import request
+
+from api.db import StatusEnum
+from api.db.services.dialog_service import DialogService
+from api.db.services.document_service import DocumentService
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.user_service import TenantService
+from api.settings import RetCode
+from api.utils import get_uuid
+from api.utils.api_utils import get_data_error_result, token_required
+from api.utils.api_utils import get_json_result
+
+
+@manager.route('/save', methods=['POST'])
+@token_required
+def save(tenant_id):
+    """Create an assistant, or update the one whose `id` is in the JSON body.
+
+    The body uses the SDK field names (`knowledgebases`, `llm`, `prompt`,
+    `avatar`); they are mapped onto the stored dialog fields before saving.
+    """
+    req = request.json
+    assistant_id = req.get("id")
+    # dataset
+    if req.get("knowledgebases") == []:
+        return get_data_error_result(retmsg="knowledgebases can not be empty list")
+    kb_list = []
+    if req.get("knowledgebases"):
+        for kb in req.get("knowledgebases"):
+            # .get() so a missing "id" key is reported instead of raising KeyError
+            if not kb.get("id"):
+                return get_data_error_result(retmsg="knowledgebase needs id")
+            if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id):
+                return get_data_error_result(retmsg="you do not own the knowledgebase")
+            if not DocumentService.query(kb_id=kb["id"]):
+                return get_data_error_result(retmsg="There is a invalid knowledgebase")
+            kb_list.append(kb["id"])
+        req["kb_ids"] = kb_list
+    # llm
+    llm = req.get("llm")
+    if llm:
+        if "model_name" in llm:
+            req["llm_id"] = llm.pop("model_name")
+        req["llm_setting"] = req.pop("llm")
+    e, tenant = TenantService.get_by_id(tenant_id)
+    if not e:
+        return get_data_error_result(retmsg="Tenant not found!")
+    # prompt
+    prompt = req.get("prompt")
+    key_mapping = {"parameters": "variables",
+                   "prologue": "opener",
+                   "quote": "show_quote",
+                   "system": "prompt",
+                   "rerank_id": "rerank_model",
+                   "vector_similarity_weight": "keywords_similarity_weight"}
+    key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
+    if prompt:
+        for new_key, old_key in key_mapping.items():
+            if old_key in prompt:
+                prompt[new_key] = prompt.pop(old_key)
+        for key in key_list:
+            if key in prompt:
+                req[key] = prompt.pop(key)
+        req["prompt_config"] = req.pop("prompt")
+    # create
+    if not assistant_id:
+        # dataset
+        if not kb_list:
+            return get_data_error_result(retmsg="knowledgebase is required!")
+        # init
+        req["id"] = get_uuid()
+        req["description"] = req.get("description", "A helpful Assistant")
+        req["icon"] = req.get("avatar", "")
+        req["top_n"] = req.get("top_n", 6)
+        req["top_k"] = req.get("top_k", 1024)
+        req["rerank_id"] = req.get("rerank_id", "")
+        req["llm_id"] = req.get("llm_id", tenant.llm_id)
+        if not req.get("name"):
+            return get_data_error_result(retmsg="name is required.")
+        if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
+            return get_data_error_result(retmsg="Duplicated assistant name in creating assistant.")
+        # tenant_id
+        if req.get("tenant_id"):
+            return get_data_error_result(retmsg="tenant_id must not be provided.")
+        req["tenant_id"] = tenant_id
+        # prompt more parameter
+        default_prompt = {
+            "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。
+            以下是知识库:
+            {knowledge}
+            以上是知识库。""",
+            "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
+            "parameters": [
+                {"key": "knowledge", "optional": False}
+            ],
+            "empty_response": "Sorry! 知识库中未找到相关内容!"
+        }
+        key_list_2 = ["system", "prologue", "parameters", "empty_response"]
+        if "prompt_config" not in req:
+            req['prompt_config'] = {}
+        for key in key_list_2:
+            temp = req['prompt_config'].get(key)
+            if not temp:
+                req['prompt_config'][key] = default_prompt[key]
+        for p in req['prompt_config']["parameters"]:
+            if p["optional"]:
+                continue
+            if req['prompt_config']["system"].find("{%s}" % p["key"]) < 0:
+                return get_data_error_result(
+                    retmsg="Parameter '{}' is not used".format(p["key"]))
+        # save
+        if not DialogService.save(**req):
+            return get_data_error_result(retmsg="Fail to new an assistant!")
+        # response
+        e, res = DialogService.get_by_id(req["id"])
+        if not e:
+            return get_data_error_result(retmsg="Fail to new an assistant!")
+        res = res.to_json()
+        renamed_dict = {}
+        for key, value in res["prompt_config"].items():
+            new_key = key_mapping.get(key, key)
+            renamed_dict[new_key] = value
+        res["prompt"] = renamed_dict
+        del res["prompt_config"]
+        new_dict = {"similarity_threshold": res["similarity_threshold"],
+                    "keywords_similarity_weight": res["vector_similarity_weight"],
+                    "top_n": res["top_n"],
+                    "rerank_model": res['rerank_id']}
+        res["prompt"].update(new_dict)
+        for key in key_list:
+            del res[key]
+        res["llm"] = res.pop("llm_setting")
+        res["llm"]["model_name"] = res.pop("llm_id")
+        del res["kb_ids"]
+        res["knowledgebases"] = req["knowledgebases"]
+        res["avatar"] = res.pop("icon")
+        return get_json_result(data=res)
+    else:
+        # authorization
+        if not DialogService.query(tenant_id=tenant_id, id=req["id"], status=StatusEnum.VALID.value):
+            return get_json_result(data=False, retmsg='You do not own the assistant', retcode=RetCode.OPERATING_ERROR)
+        # prompt
+        e, res = DialogService.get_by_id(req["id"])
+        if not e:
+            return get_data_error_result(retmsg="Assistant not found!")
+        res = res.to_json()
+        if "name" in req:
+            if not req.get("name"):
+                return get_data_error_result(retmsg="name can not be empty.")
+            if req["name"].lower() != res["name"].lower() \
+                    and len(DialogService.query(name=req["name"], tenant_id=tenant_id,status=StatusEnum.VALID.value)) > 0:
+                return get_data_error_result(retmsg="Duplicated assistant name in updating assistant.")
+        if "prompt_config" in req:
+            res["prompt_config"].update(req["prompt_config"])
+            for p in res["prompt_config"]["parameters"]:
+                if p["optional"]:
+                    continue
+                if res["prompt_config"]["system"].find("{%s}" % p["key"]) < 0:
+                    return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"]))
+        if "llm_setting" in req:
+            res["llm_setting"].update(req["llm_setting"])
+        req["prompt_config"] = res["prompt_config"]
+        req["llm_setting"] = res["llm_setting"]
+        # avatar
+        if "avatar" in req:
+            req["icon"] = req.pop("avatar")
+        req.pop("id")
+        if "knowledgebases" in req:
+            req.pop("knowledgebases")
+        if not DialogService.update_by_id(assistant_id, req):
+            return get_data_error_result(retmsg="Assistant not found!")
+        return get_json_result(data=True)
+
+
+@manager.route('/delete', methods=['DELETE'])
+@token_required
+def delete(tenant_id):
+    """Soft-delete the tenant's assistant given by the `id` query argument (status -> INVALID)."""
+    req = request.args
+    if "id" not in req:
+        return get_data_error_result(retmsg="id is required")
+    assistant_id = req['id']
+    if not DialogService.query(tenant_id=tenant_id, id=assistant_id,status=StatusEnum.VALID.value):
+        return get_json_result(data=False, retmsg='you do not own the assistant.', retcode=RetCode.OPERATING_ERROR)
+
+    temp_dict = {"status": StatusEnum.INVALID.value}
+    DialogService.update_by_id(assistant_id, temp_dict)
+    return get_json_result(data=True)
+
+
+@manager.route('/get', methods=['GET'])
+@token_required
+def get(tenant_id):
+    """Return one valid assistant of the tenant, looked up by `id` (cross-checked against `name` if given) or by `name`."""
+    req = request.args
+    if "id" in req:
+        assistant_id = req["id"]
+        ass = DialogService.query(tenant_id=tenant_id, id=assistant_id,status=StatusEnum.VALID.value)
+        if not ass:
+            return get_json_result(data=False, retmsg='You do not own the assistant.', retcode=RetCode.OPERATING_ERROR)
+        if "name" in req:
+            name = req["name"]
+            if ass[0].name != name:
+                return get_json_result(data=False, retmsg='name does not match id.', retcode=RetCode.OPERATING_ERROR)
+        res=ass[0].to_json()
+    else:
+        if "name" in req:
+            name = req["name"]
+            ass = DialogService.query(name=name, tenant_id=tenant_id,status=StatusEnum.VALID.value)
+            if not ass:
+                return get_json_result(data=False, retmsg='You do not own the assistant.',retcode=RetCode.OPERATING_ERROR)
+            res=ass[0].to_json()
+        else:
+            return get_data_error_result(retmsg="At least one of `id` or `name` must be provided.")
+    renamed_dict = {}
+    key_mapping = {"parameters": "variables",
+                   "prologue": "opener",
+                   "quote": "show_quote",
+                   "system": "prompt",
+                   "rerank_id": "rerank_model",
+                   "vector_similarity_weight": "keywords_similarity_weight"}
+    key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
+    for key, value in res["prompt_config"].items():
+        new_key = key_mapping.get(key, key)
+        renamed_dict[new_key] = value
+    res["prompt"] = renamed_dict
+    del res["prompt_config"]
+    new_dict = {"similarity_threshold": res["similarity_threshold"],
+                "keywords_similarity_weight": res["vector_similarity_weight"],
+                "top_n": res["top_n"],
+                "rerank_model": res['rerank_id']}
+    res["prompt"].update(new_dict)
+    for key in key_list:
+        del res[key]
+    res["llm"] = res.pop("llm_setting")
+    res["llm"]["model_name"] = res.pop("llm_id")
+    kb_list = []
+    for kb_id in res["kb_ids"]:
+        kb = KnowledgebaseService.query(id=kb_id)
+        if kb:  # skip ids whose knowledgebase can no longer be found (kb[0] would raise IndexError)
+            kb_list.append(kb[0].to_json())
+    del res["kb_ids"]
+    res["knowledgebases"] = kb_list
+    res["avatar"] = res.pop("icon")
+    return get_json_result(data=res)
+
+
+@manager.route('/list', methods=['GET'])
+@token_required
+def list_assistants(tenant_id):
+    """Return all valid assistants of the tenant, ordered by `create_time` with `reverse=True`."""
+    assts = DialogService.query(
+        tenant_id=tenant_id,
+        status=StatusEnum.VALID.value,
+        reverse=True,
+        order_by=DialogService.model.create_time)
+    assts = [d.to_dict() for d in assts]
+    list_assts=[]
+    key_mapping = {"parameters": "variables",
+                   "prologue": "opener",
+                   "quote": "show_quote",
+                   "system": "prompt",
+                   "rerank_id": "rerank_model",
+                   "vector_similarity_weight": "keywords_similarity_weight"}
+    key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
+    for res in assts:
+        # one dict per assistant: a single shared dict made every entry alias the same prompt
+        renamed_dict = {}
+        for key, value in res["prompt_config"].items():
+            new_key = key_mapping.get(key, key)
+            renamed_dict[new_key] = value
+        res["prompt"] = renamed_dict
+        del res["prompt_config"]
+        new_dict = {"similarity_threshold": res["similarity_threshold"],
+                    "keywords_similarity_weight": res["vector_similarity_weight"],
+                    "top_n": res["top_n"],
+                    "rerank_model": res['rerank_id']}
+        res["prompt"].update(new_dict)
+        for key in key_list:
+            del res[key]
+        res["llm"] = res.pop("llm_setting")
+        res["llm"]["model_name"] = res.pop("llm_id")
+        kb_list = []
+        for kb_id in res["kb_ids"]:
+            kb = KnowledgebaseService.query(id=kb_id)
+            if kb:  # skip ids whose knowledgebase can no longer be found
+                kb_list.append(kb[0].to_json())
+        del res["kb_ids"]
+        res["knowledgebases"] = kb_list
+        res["avatar"] = res.pop("icon")
+        list_assts.append(res)
+    return get_json_result(data=list_assts)
diff --git a/sdk/python/ragflow/__init__.py b/sdk/python/ragflow/__init__.py
index fbdb1bcea..fd76c7559 100644
--- a/sdk/python/ragflow/__init__.py
+++ b/sdk/python/ragflow/__init__.py
@@ -3,4 +3,5 @@ import importlib.metadata
 __version__ = importlib.metadata.version("ragflow")
 
 from .ragflow import RAGFlow
-from .modules.dataset import DataSet
\ No newline at end of file
+from .modules.dataset import DataSet
+from .modules.chat_assistant import Assistant
\ No newline at end of file
diff --git a/sdk/python/ragflow/modules/chat_assistant.py b/sdk/python/ragflow/modules/chat_assistant.py
new file mode 100644
index 000000000..d5ec05bdf
--- /dev/null
+++ b/sdk/python/ragflow/modules/chat_assistant.py
@@ -0,0 +1,63 @@
+from .base import Base
+
+
+class Assistant(Base):
+    """Client-side assistant object; the values below are defaults set before `Base.__init__` receives `res_dict`."""
+    def __init__(self, rag, res_dict):
+        self.id=""
+        self.name = "assistant"
+        self.avatar = "path/to/avatar"
+        self.knowledgebases = ["kb1"]
+        self.llm = Assistant.LLM(rag, {})
+        self.prompt = Assistant.Prompt(rag, {})
+        super().__init__(rag, res_dict)
+
+    class LLM(Base):
+        """Model name and sampling settings sent as the `llm` field of an assistant."""
+        def __init__(self, rag, res_dict):
+            self.model_name = "deepseek-chat"
+            self.temperature = 0.1
+            self.top_p = 0.3
+            self.presence_penalty = 0.4
+            self.frequency_penalty = 0.7
+            self.max_tokens = 512
+            super().__init__(rag, res_dict)
+
+    class Prompt(Base):
+        """Retrieval and prompt settings sent as the `prompt` field of an assistant."""
+        def __init__(self, rag, res_dict):
+            self.similarity_threshold = 0.2
+            self.keywords_similarity_weight = 0.7
+            self.top_n = 8
+            self.variables = [{"key": "knowledge", "optional": True}]
+            self.rerank_model = None
+            self.empty_response = None
+            self.opener = "Hi! I'm your assistant, what can I do for you?"
+            self.show_quote = True
+            self.prompt = (
+                "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. "
+                "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, "
+                "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' "
+                "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base."
+            )
+            super().__init__(rag, res_dict)
+
+    def save(self) -> bool:
+        """POST this assistant to `/assistant/save`; return True on success, else raise Exception with the server's `retmsg`."""
+        res = self.post('/assistant/save',
+                        {"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases":self.knowledgebases,
+                         "llm":self.llm.to_json(),"prompt":self.prompt.to_json()
+                         })
+        res = res.json()
+        if res.get("retmsg") == "success":
+            return True
+        raise Exception(res["retmsg"])
+
+    def delete(self) -> bool:
+        """Call `/assistant/delete` for this assistant's id; return True on success, else raise Exception with the server's `retmsg`."""
+        res = self.rm('/assistant/delete',
+                      {"id": self.id})
+        res = res.json()
+        if res.get("retmsg") == "success":
+            return True
+        raise Exception(res["retmsg"])
diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py
index 7a114861b..68713f03a 100644
--- a/sdk/python/ragflow/ragflow.py
+++ b/sdk/python/ragflow/ragflow.py
@@ -17,6 +17,8 @@
 from typing import List
 
 import requests
+
+from .modules.chat_assistant import Assistant
 from .modules.dataset import DataSet
 
 
@@ -78,3 +80,72 @@ class RAGFlow:
         if res.get("retmsg") == "success":
             return DataSet(self, res['data'])
         raise Exception(res["retmsg"])
+
+    def create_assistant(self, name: str = "assistant", avatar: str = "path", knowledgebases: List[DataSet] = None,
+                         llm: Assistant.LLM = None, prompt: Assistant.Prompt = None) -> Assistant:
+        """Create an assistant on the server and return it.
+
+        Defaults are filled in for `llm` and `prompt` when they are omitted.
+        Raises Exception with the server's `retmsg` on failure.
+        """
+        # `knowledgebases` defaults to None (not a shared mutable []); None is treated as empty
+        datasets = [dataset.to_json() for dataset in knowledgebases or []]
+
+        if llm is None:
+            llm = Assistant.LLM(self, {"model_name": "deepseek-chat",
+                                       "temperature": 0.1,
+                                       "top_p": 0.3,
+                                       "presence_penalty": 0.4,
+                                       "frequency_penalty": 0.7,
+                                       "max_tokens": 512, })
+        if prompt is None:
+            prompt = Assistant.Prompt(self, {"similarity_threshold": 0.2,
+                                             "keywords_similarity_weight": 0.7,
+                                             "top_n": 8,
+                                             "variables": [{
+                                                 "key": "knowledge",
+                                                 "optional": True
+                                             }], "rerank_model": "",
+                                             "empty_response": None,
+                                             "opener": None,
+                                             "show_quote": True,
+                                             "prompt": None})
+        if prompt.opener is None:
+            prompt.opener = "Hi! I'm your assistant, what can I do for you?"
+        if prompt.prompt is None:
+            prompt.prompt = (
+                "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. "
+                "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, "
+                "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' "
+                "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base."
+            )
+
+        temp_dict = {"name": name,
+                     "avatar": avatar,
+                     "knowledgebases": datasets,
+                     "llm": llm.to_json(),
+                     "prompt": prompt.to_json()}
+        res = self.post("/assistant/save", temp_dict)
+        res = res.json()
+        if res.get("retmsg") == "success":
+            return Assistant(self, res["data"])
+        raise Exception(res["retmsg"])
+
+    def get_assistant(self, id: str = None, name: str = None) -> Assistant:
+        """Fetch one assistant from `/assistant/get` by `id` and/or `name`; raise Exception with `retmsg` on failure."""
+        res = self.get("/assistant/get", {"id": id, "name": name})
+        res = res.json()
+        if res.get("retmsg") == "success":
+            return Assistant(self, res['data'])
+        raise Exception(res["retmsg"])
+
+    def list_assistants(self) -> List[Assistant]:
+        """Fetch `/assistant/list` and wrap each entry in an Assistant; raise Exception with `retmsg` on failure."""
+        res = self.get("/assistant/list")
+        res = res.json()
+        result_list = []
+        if res.get("retmsg") == "success":
+            for data in res['data']:
+                result_list.append(Assistant(self, data))
+            return result_list
+        raise Exception(res["retmsg"])
\ No newline at end of file
diff --git a/sdk/python/test/common.py b/sdk/python/test/common.py
index 5feca4777..c92e34dec 100644
--- a/sdk/python/test/common.py
+++ b/sdk/python/test/common.py
@@ -1,4 +1,4 @@
 
 
-API_KEY = 'ragflow-k0N2I1MzQwNjNhMzExZWY5ODg1MDI0Mm'
+API_KEY = 'ragflow-k0YzUxMGY4NjY5YTExZWY5MjI5MDI0Mm'
 HOST_ADDRESS = 'http://127.0.0.1:9380'
\ No newline at end of file
diff --git a/sdk/python/test/t_assistant.py b/sdk/python/test/t_assistant.py
new file mode 100644
index 000000000..7d70a337b
--- /dev/null
+++ b/sdk/python/test/t_assistant.py
@@ -0,0 +1,66 @@
+from ragflow import RAGFlow, Assistant
+
+from common import API_KEY, HOST_ADDRESS
+from test_sdkbase import TestSdk
+
+
+class TestAssistant(TestSdk):
+    def test_create_assistant_with_success(self):
+        """
+        Test creating an assistant with success
+        """
+        rag = RAGFlow(API_KEY, HOST_ADDRESS)
+        kb = rag.get_dataset(name="God")
+        assistant = rag.create_assistant("God",knowledgebases=[kb])
+        if isinstance(assistant, Assistant):
+            assert assistant.name == "God", "Name does not match."
+        else:
+            assert False, f"Failed to create assistant, error: {assistant}"
+
+    def test_update_assistant_with_success(self):
+        """
+        Test updating an assistant with success.
+        """
+        rag = RAGFlow(API_KEY, HOST_ADDRESS)
+        kb = rag.get_dataset(name="God")
+        assistant = rag.create_assistant("ABC",knowledgebases=[kb])
+        if isinstance(assistant, Assistant):
+            assert assistant.name == "ABC", "Name does not match."
+            assistant.name = 'DEF'
+            res = assistant.save()
+            assert res is True, f"Failed to update assistant, error: {res}"
+        else:
+            assert False, f"Failed to create assistant, error: {assistant}"
+
+    def test_delete_assistant_with_success(self):
+        """
+        Test deleting an assistant with success
+        """
+        rag = RAGFlow(API_KEY, HOST_ADDRESS)
+        kb = rag.get_dataset(name="God")
+        assistant = rag.create_assistant("MA",knowledgebases=[kb])
+        if isinstance(assistant, Assistant):
+            assert assistant.name == "MA", "Name does not match."
+            res = assistant.delete()
+            assert res is True, f"Failed to delete assistant, error: {res}"
+        else:
+            assert False, f"Failed to create assistant, error: {assistant}"
+
+    def test_list_assistants_with_success(self):
+        """
+        Test listing assistants with success
+        """
+        rag = RAGFlow(API_KEY, HOST_ADDRESS)
+        list_assistants = rag.list_assistants()
+        assert len(list_assistants) > 0, "Do not exist any assistant"
+        for assistant in list_assistants:
+            assert isinstance(assistant, Assistant), "Existence type is not assistant."
+
+    def test_get_detail_assistant_with_success(self):
+        """
+        Test getting an assistant's detail with success
+        """
+        rag = RAGFlow(API_KEY, HOST_ADDRESS)
+        assistant = rag.get_assistant(name="God")
+        assert isinstance(assistant, Assistant), f"Failed to get assistant, error: {assistant}."
+        assert assistant.name == "God", "Name does not match"