diff --git a/api/db/init_data.py b/api/db/init_data.py index c649e0e7b..ad3c16d98 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -109,7 +109,7 @@ factory_infos = [{ "name": "Ollama", "logo": "", "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", - "status": "1", + "status": "1", }, { "name": "Moonshot", "logo": "", @@ -123,8 +123,8 @@ factory_infos = [{ }, { "name": "Xinference", "logo": "", - "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", - "status": "1", + "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION,TEXT RE-RANK", + "status": "1", },{ "name": "Youdao", "logo": "", diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 513c1c44e..b07d98064 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -68,4 +68,5 @@ RerankModel = { "BAAI": DefaultRerank, "Jina": JinaRerank, "Youdao": YoudaoRerank, + "Xinference": XInferenceRerank } diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index c56a3ccea..fabf11ec5 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -136,4 +136,22 @@ class YoudaoRerank(DefaultRerank): else: res.extend(scores) return np.array(res), token_count +class XInferenceRerank(Base): + def __init__(self,model_name="",base_url=""): + self.model_name=model_name + self.base_url=base_url + self.headers = { + "Content-Type": "application/json", + "accept": "application/json" + } + def similarity(self, query: str, texts: list): + data = { + "model":self.model_name, + "query":query, + "return_documents": "true", + "return_len": "true", + "documents":texts + } + res = requests.post(self.base_url, headers=self.headers, json=data).json() + return np.array([d["relevance_score"] for d in res["results"]]),res["tokens"]["input_tokens"]+res["tokens"]["output_tokens"] diff --git a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx index cded758e6..fe3d29967 100644 --- a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx @@ -74,6 +74,7 @@ const OllamaModal = ({