From d83911b632aee4db1a1314ba73f19cd2ca8e3a16 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Fri, 21 Mar 2025 12:43:32 +0800 Subject: [PATCH] Fix: huggingface rerank model issue. (#6385) ### What problem does this PR solve? #6348 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- rag/llm/rerank_model.py | 2 +- rag/svr/task_executor.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 474b39a81..372ba9e71 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -537,7 +537,7 @@ class HuggingfaceRerank(DefaultRerank): return np.array(scores) def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"): - self.model_name = model_name + self.model_name = model_name.split("___")[0] self.base_url = base_url def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]: diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index e8ccb5be9..4d0e860be 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -57,7 +57,7 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume, from rag.nlp import search, rag_tokenizer from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor from rag.settings import DOC_MAXIMUM_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD -from rag.utils import num_tokens_from_string +from rag.utils import num_tokens_from_string, truncate from rag.utils.redis_conn import REDIS_CONN from rag.utils.storage_factory import STORAGE_IMPL from graphrag.utils import chat_limiter @@ -404,7 +404,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None): cnts_ = np.array([]) for i in range(0, len(cnts), batch_size): - vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(cnts[i: i + batch_size])) + vts, c = await trio.to_thread.run_sync(lambda: mdl.encode([truncate(c, mdl.max_length-10) for c in cnts[i: i + batch_size]])) if len(cnts_) == 0: cnts_ = vts else: