Feat: support huggingface re-rank model. (#5684)

### What problem does this PR solve?

#5658

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu 2025-03-06 10:44:04 +08:00 committed by GitHub
parent 5f62f0c9d7
commit b8da2eeb69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 46 additions and 7 deletions

View File

@ -3284,7 +3284,7 @@
{
"name": "HuggingFace",
"logo": "",
"tags": "TEXT EMBEDDING",
"tags": "TEXT EMBEDDING,TEXT RE-RANK",
"status": "1",
"llm": []
},

View File

@ -107,6 +107,7 @@ from .cv_model import (
YiCV,
HunyuanCV,
)
from .rerank_model import (
LocalAIRerank,
DefaultRerank,
@ -123,7 +124,9 @@ from .rerank_model import (
VoyageRerank,
QWenRerank,
GPUStackRerank,
HuggingfaceRerank,
)
from .sequence2txt_model import (
GPTSeq2txt,
QWenSeq2txt,
@ -132,6 +135,7 @@ from .sequence2txt_model import (
TencentCloudSeq2txt,
GPUStackSeq2txt,
)
from .tts_model import (
FishAudioTTS,
QwenTTS,
@ -255,6 +259,7 @@ RerankModel = {
"Voyage AI": VoyageRerank,
"Tongyi-Qianwen": QWenRerank,
"GPUStack": GPUStackRerank,
"HuggingFace": HuggingfaceRerank,
}
Seq2txtModel = {

View File

@ -31,7 +31,6 @@ from rag.utils import num_tokens_from_string, truncate
import json
def sigmoid(x):
return 1 / (1 + np.exp(-x))
@ -90,7 +89,6 @@ class DefaultRerank(Base):
self._dynamic_batch_size = 8
self._min_batch_size = 1
def torch_empty_cache(self):
try:
import torch
@ -282,6 +280,7 @@ class LocalAIRerank(Base):
return rank, token_count
class NvidiaRerank(Base):
def __init__(
self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
@ -513,6 +512,40 @@ class QWenRerank(Base):
else:
raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
class HuggingfaceRerank(DefaultRerank):
@staticmethod
def post(query: str, texts: list, url="127.0.0.1"):
exc = None
scores = [0 for _ in range(len(texts))]
batch_size = 8
for i in range(0, len(texts), batch_size):
try:
res = requests.post(f"http://{url}/rerank", headers={"Content-Type": "application/json"},
json={"query": query, "texts": texts[i: i + batch_size],
"raw_scores": False, "truncate": True})
for o in res.json():
scores[o["index"] + i] = o["score"]
except Exception as e:
exc = e
if exc:
raise exc
return np.array(scores)
def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
self.model_name = model_name
self.base_url = base_url
def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
if not texts:
return np.array([]), 0
token_count = 0
for t in texts:
token_count += num_tokens_from_string(t)
return HuggingfaceRerank.post(query, texts, self.base_url), token_count
class GPUStackRerank(Base):
def __init__(
self, key, model_name, base_url
@ -560,5 +593,6 @@ class GPUStackRerank(Base):
)
except httpx.HTTPStatusError as e:
raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
raise ValueError(
f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")