From 4991107822144cbea3f5c4896d91a5afb0e7b22f Mon Sep 17 00:00:00 2001
From: 0000sir <0000sir@users.noreply.github.com>
Date: Wed, 16 Oct 2024 10:21:08 +0800
Subject: [PATCH] Fix keys of Xinference-deployed models, especially when they
 share a model name with publicly hosted models. (#2832)

### What problem does this PR solve?

Fix the keys of Xinference-deployed models, especially when a self-deployed model has the same name as a publicly hosted model.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
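For illustration, here is a minimal sketch (not RAGFlow code) of the de-duplication issue fixed in `list_app()`: keying on the model name alone collapses a self-deployed Xinference model with a publicly hosted model of the same name, while the composite `llm_name@fid` key keeps both. The model names and factory ids below are made up.

```python
# Minimal sketch (not RAGFlow code); model names and factory ids are made up.
llms = [
    {"llm_name": "bge-m3", "fid": "Xinference", "model_type": "embedding"},  # self-deployed
    {"llm_name": "bge-m3", "fid": "BAAI", "model_type": "embedding"},        # publicly hosted
]

# Old key: model name only -- both entries collapse into one, so the
# self-deployed model is hidden behind the hosted one.
by_name = {m["llm_name"] for m in llms}
assert len(by_name) == 1

# New key: "<llm_name>@<fid>" -- the two deployments stay distinct,
# mirroring the llm_name + "@" + fid change in this patch.
by_name_and_factory = {m["llm_name"] + "@" + m["fid"] for m in llms}
assert len(by_name_and_factory) == 2
```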
---------

Co-authored-by: 0000sir <0000sir@gmail.com>
Co-authored-by: Kevin Hu
---
 api/apps/llm_app.py           |  4 ++--
 api/apps/sdk/doc.py           | 19 +++++++++----------
 rag/llm/cv_model.py           |  2 +-
 rag/llm/embedding_model.py    |  2 +-
 rag/llm/rerank_model.py       |  3 ++-
 rag/llm/sequence2txt_model.py |  1 +
 sdk/python/ragflow/ragflow.py |  6 ++++++
 7 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 432c21ff6..720978963 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -343,10 +343,10 @@ def list_app():
     for m in llms:
         m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deploied
 
-    llm_set = set([m["llm_name"] for m in llms])
+    llm_set = set([m["llm_name"]+"@"+m["fid"] for m in llms])
     for o in objs:
         if not o.api_key:continue
-        if o.llm_name in llm_set:continue
+        if o.llm_name+"@"+o.llm_factory in llm_set:continue
         llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
 
     res = {}
diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py
index 81861a70c..502697e14 100644
--- a/api/apps/sdk/doc.py
+++ b/api/apps/sdk/doc.py
@@ -494,25 +494,24 @@ def set(tenant_id,dataset_id,document_id,chunk_id):
 
 
 
-@manager.route('/retrieval', methods=['GET'])
+@manager.route('/retrieval', methods=['POST'])
 @token_required
 def retrieval_test(tenant_id):
-    req = request.args
-    req_json = request.json
-    if not req_json.get("datasets"):
+    req = request.json
+    if not req.get("datasets"):
         return get_error_data_result("`datasets` is required.")
-    for id in req_json.get("datasets"):
+    kb_id = req["datasets"]
+    if isinstance(kb_id, str): kb_id = [kb_id]
+    for id in kb_id:
         if not KnowledgebaseService.query(id=id,tenant_id=tenant_id):
             return get_error_data_result(f"You don't own the dataset {id}.")
     if "question" not in req_json:
         return get_error_data_result("`question` is required.")
     page = int(req.get("offset", 1))
     size = int(req.get("limit", 30))
-    question = req_json["question"]
-    kb_id = req_json["datasets"]
-    if isinstance(kb_id, str): kb_id = [kb_id]
-    doc_ids = req_json.get("documents", [])
-    similarity_threshold = float(req.get("similarity_threshold", 0.0))
+    question = req["question"]
+    doc_ids = req.get("documents", [])
+    similarity_threshold = float(req.get("similarity_threshold", 0.2))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
     top = int(req.get("top_k", 1024))
     if req.get("highlight")=="False" or req.get("highlight")=="false":
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 97e02911f..70e9f24ea 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -453,7 +453,7 @@ class XinferenceCV(Base):
     def __init__(self, key, model_name="", lang="Chinese", base_url=""):
         if base_url.split("/")[-1] != "v1":
             base_url = os.path.join(base_url, "v1")
-        self.client = OpenAI(api_key="xxx", base_url=base_url)
+        self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
 
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index c7af5c506..ea994746a 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -274,7 +274,7 @@ class XinferenceEmbed(Base):
     def __init__(self, key, model_name="", base_url=""):
         if base_url.split("/")[-1] != "v1":
             base_url = os.path.join(base_url, "v1")
-        self.client = OpenAI(api_key="xxx", base_url=base_url)
+        self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
 
     def encode(self, texts: list, batch_size=32):
diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py
index 67e4fd4a9..6b9cbae7b 100644
--- a/rag/llm/rerank_model.py
+++ b/rag/llm/rerank_model.py
@@ -162,7 +162,8 @@ class XInferenceRerank(Base):
         self.base_url = base_url
         self.headers = {
             "Content-Type": "application/json",
-            "accept": "application/json"
+            "accept": "application/json",
+            "Authorization": f"Bearer {key}"
         }
 
     def similarity(self, query: str, texts: list):
diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py
index e2f76e16d..a2d3ea0ef 100644
--- a/rag/llm/sequence2txt_model.py
+++ b/rag/llm/sequence2txt_model.py
@@ -90,6 +90,7 @@ class XinferenceSeq2txt(Base):
     def __init__(self,key,model_name="whisper-small",**kwargs):
         self.base_url = kwargs.get('base_url', None)
         self.model_name = model_name
+        self.key = key
 
     def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
         if isinstance(audio, str):
diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py
index 962914396..c50c61929 100644
--- a/sdk/python/ragflow/ragflow.py
+++ b/sdk/python/ragflow/ragflow.py
@@ -74,6 +74,12 @@ class RAGFlow:
         if res.get("code") != 0:
             raise Exception(res["message"])
 
+    def get_dataset(self,name: str):
+        _list = self.list_datasets(name=name)
+        if len(_list) > 0:
+            return _list[0]
+        raise Exception("Dataset %s not found" % name)
+
     def list_datasets(self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True,
                       id: str = None, name: str = None) -> \
             List[DataSet]:
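A hedged usage sketch for the new `RAGFlow.get_dataset` helper added to the Python SDK by this patch; the constructor arguments, base URL, and dataset name below are assumptions for illustration, not values taken from the diff.

```python
# Assumes the ragflow package exports RAGFlow and that the client is built
# from an API key and base URL; both values here are placeholders.
from ragflow import RAGFlow

rag = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://127.0.0.1:9380")

try:
    # get_dataset wraps list_datasets(name=...) and returns the first match.
    dataset = rag.get_dataset(name="my_dataset")
    print(dataset)
except Exception as exc:
    # Raised as "Dataset my_dataset not found" when no dataset matches.
    print(exc)
```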