diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index 39d780304..27e66ac6c 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -3284,7 +3284,7 @@
         {
             "name": "HuggingFace",
             "logo": "",
-            "tags": "TEXT EMBEDDING",
+            "tags": "TEXT EMBEDDING,TEXT RE-RANK",
             "status": "1",
             "llm": []
         },
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index bcea2e558..5db8970b5 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -107,6 +107,7 @@ from .cv_model import (
     YiCV,
     HunyuanCV,
 )
+
 from .rerank_model import (
     LocalAIRerank,
     DefaultRerank,
@@ -123,7 +124,9 @@ from .rerank_model import (
     VoyageRerank,
     QWenRerank,
     GPUStackRerank,
+    HuggingfaceRerank,
 )
+
 from .sequence2txt_model import (
     GPTSeq2txt,
     QWenSeq2txt,
@@ -132,6 +135,7 @@ from .sequence2txt_model import (
     TencentCloudSeq2txt,
     GPUStackSeq2txt,
 )
+
 from .tts_model import (
     FishAudioTTS,
     QwenTTS,
@@ -255,6 +259,7 @@ RerankModel = {
     "Voyage AI": VoyageRerank,
     "Tongyi-Qianwen": QWenRerank,
     "GPUStack": GPUStackRerank,
+    "HuggingFace": HuggingfaceRerank,
 }
 
 Seq2txtModel = {
diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py
index 443075153..1ea2faedf 100644
--- a/rag/llm/rerank_model.py
+++ b/rag/llm/rerank_model.py
@@ -31,7 +31,6 @@ from rag.utils import num_tokens_from_string, truncate
 import json
 
 
-
 def sigmoid(x):
     return 1 / (1 + np.exp(-x))
 
@@ -87,10 +86,9 @@ class DefaultRerank(Base):
                                                       local_dir_use_symlinks=False)
                         DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
         self._model = DefaultRerank._model
-        self._dynamic_batch_size = 8 
+        self._dynamic_batch_size = 8
         self._min_batch_size = 1
 
-
     def torch_empty_cache(self):
         try:
             import torch
@@ -112,7 +110,7 @@ class DefaultRerank(Base):
             while retry_count < max_retries:
                 try:
                     # call subclass implemented batch processing calculation
-                    batch_scores = self._compute_batch_scores(pairs[i:i+current_batch])
+                    batch_scores = self._compute_batch_scores(pairs[i:i + current_batch])
                     res.extend(batch_scores)
                     i += current_batch
                     self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
@@ -282,6 +280,7 @@ class LocalAIRerank(Base):
 
         return rank, token_count
 
+
 class NvidiaRerank(Base):
     def __init__(
             self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
@@ -513,6 +512,51 @@ class QWenRerank(Base):
         else:
             raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
 
+
+class HuggingfaceRerank(DefaultRerank):
+    """Rerank through a remote HTTP ``/rerank`` service instead of a local model."""
+
+    @staticmethod
+    def post(query: str, texts: list, url="127.0.0.1"):
+        """Score ``texts`` against ``query`` via ``POST <url>/rerank``.
+
+        Texts are sent in batches of 8; each response item's ``index`` is
+        treated as relative to its batch and offset back into the full list.
+        ``url`` may be a bare ``host[:port]`` or include ``http(s)://``.
+        Raises on the first failed request or non-2xx response.
+        """
+        # Add a scheme only when missing: the constructor's default base_url
+        # already carries one, and prefixing again gave "http://http://...".
+        if not url.startswith(("http://", "https://")):
+            url = f"http://{url}"
+        endpoint = f"{url.rstrip('/')}/rerank"
+        scores = [0] * len(texts)
+        batch_size = 8
+        for i in range(0, len(texts), batch_size):
+            res = requests.post(endpoint, headers={"Content-Type": "application/json"},
+                                json={"query": query, "texts": texts[i: i + batch_size],
+                                      "raw_scores": False, "truncate": True})
+            # Surface HTTP errors directly rather than failing later while
+            # parsing an error payload as if it were a score list.
+            res.raise_for_status()
+            for o in res.json():
+                scores[o["index"] + i] = o["score"]
+        return np.array(scores)
+
+    def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
+        # DefaultRerank.__init__ is deliberately not called: it sets up a local
+        # FlagReranker, which this HTTP-backed client does not use.
+        self.model_name = model_name
+        self.base_url = base_url
+
+    def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
+        """Return ``(scores, token_count)`` for ``texts``; empty input yields ``(array([]), 0)``."""
+        if not texts:
+            return np.array([]), 0
+        token_count = sum(num_tokens_from_string(t) for t in texts)
+        return HuggingfaceRerank.post(query, texts, self.base_url), token_count
+
+
 class GPUStackRerank(Base):
     def __init__(
             self, key, model_name, base_url
@@ -521,7 +565,7 @@ class GPUStackRerank(Base):
             raise ValueError("url cannot be None")
 
         self.model_name = model_name
-        self.base_url = str(URL(base_url)/ "v1" / "rerank")
+        self.base_url = str(URL(base_url) / "v1" / "rerank")
         self.headers = {
             "accept": "application/json",
             "content-type": "application/json",
@@ -560,5 +604,6 @@ class GPUStackRerank(Base):
             )
 
         except httpx.HTTPStatusError as e:
-            raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
+            raise ValueError(
+                f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
 