From 593ffc406747eaa8056dad8a468b8789f308b3e2 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Thu, 5 Dec 2024 13:28:42 +0800 Subject: [PATCH] Fix HuggingFace model error. (#3870) ### What problem does this PR solve? #3865 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- rag/llm/chat_model.py | 4 ++-- rag/llm/embedding_model.py | 3 ++- rag/llm/rerank_model.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index c65078d2d..26ce3e148 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -22,7 +22,7 @@ from abc import ABC from openai import OpenAI import openai from ollama import Client -from rag.nlp import is_chinese +from rag.nlp import is_chinese, is_english from rag.utils import num_tokens_from_string from groq import Groq import os @@ -123,7 +123,7 @@ class HuggingFaceChat(Base): raise ValueError("Local llm url cannot be None") if base_url.split("/")[-1] != "v1": base_url = os.path.join(base_url, "v1") - super().__init__(key, model_name, base_url) + super().__init__(key, model_name.split("___")[0], base_url) class DeepSeekChat(Base): diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index bdebdc3a3..d4fdc0956 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -745,7 +745,7 @@ class HuggingFaceEmbed(Base): if not model_name: raise ValueError("Model name cannot be None") self.key = key - self.model_name = model_name + self.model_name = model_name.split("___")[0] self.base_url = base_url or "http://127.0.0.1:8080" def encode(self, texts: list): @@ -775,6 +775,7 @@ class HuggingFaceEmbed(Base): else: raise Exception(f"Error: {response.status_code} - {response.text}") + class VolcEngineEmbed(OpenAIEmbed): def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"): if not base_url: diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 271f41ef1..925e08aa2 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -194,7 +194,7 @@ class LocalAIRerank(Base): "Content-Type": "application/json", "Authorization": f"Bearer {key}" } - self.model_name = model_name.replace("___LocalAI","") + self.model_name = model_name.split("___")[0] def similarity(self, query: str, texts: list): # noway to config Ragflow , use fix setting