mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-22 22:20:07 +08:00
Feat: support huggingface re-rank model. (#5684)
### What problem does this PR solve? #5658 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
5f62f0c9d7
commit
b8da2eeb69
@ -3284,7 +3284,7 @@
|
|||||||
{
|
{
|
||||||
"name": "HuggingFace",
|
"name": "HuggingFace",
|
||||||
"logo": "",
|
"logo": "",
|
||||||
"tags": "TEXT EMBEDDING",
|
"tags": "TEXT EMBEDDING,TEXT RE-RANK",
|
||||||
"status": "1",
|
"status": "1",
|
||||||
"llm": []
|
"llm": []
|
||||||
},
|
},
|
||||||
|
@ -107,6 +107,7 @@ from .cv_model import (
|
|||||||
YiCV,
|
YiCV,
|
||||||
HunyuanCV,
|
HunyuanCV,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .rerank_model import (
|
from .rerank_model import (
|
||||||
LocalAIRerank,
|
LocalAIRerank,
|
||||||
DefaultRerank,
|
DefaultRerank,
|
||||||
@ -123,7 +124,9 @@ from .rerank_model import (
|
|||||||
VoyageRerank,
|
VoyageRerank,
|
||||||
QWenRerank,
|
QWenRerank,
|
||||||
GPUStackRerank,
|
GPUStackRerank,
|
||||||
|
HuggingfaceRerank,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .sequence2txt_model import (
|
from .sequence2txt_model import (
|
||||||
GPTSeq2txt,
|
GPTSeq2txt,
|
||||||
QWenSeq2txt,
|
QWenSeq2txt,
|
||||||
@ -132,6 +135,7 @@ from .sequence2txt_model import (
|
|||||||
TencentCloudSeq2txt,
|
TencentCloudSeq2txt,
|
||||||
GPUStackSeq2txt,
|
GPUStackSeq2txt,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .tts_model import (
|
from .tts_model import (
|
||||||
FishAudioTTS,
|
FishAudioTTS,
|
||||||
QwenTTS,
|
QwenTTS,
|
||||||
@ -255,6 +259,7 @@ RerankModel = {
|
|||||||
"Voyage AI": VoyageRerank,
|
"Voyage AI": VoyageRerank,
|
||||||
"Tongyi-Qianwen": QWenRerank,
|
"Tongyi-Qianwen": QWenRerank,
|
||||||
"GPUStack": GPUStackRerank,
|
"GPUStack": GPUStackRerank,
|
||||||
|
"HuggingFace": HuggingfaceRerank,
|
||||||
}
|
}
|
||||||
|
|
||||||
Seq2txtModel = {
|
Seq2txtModel = {
|
||||||
|
@ -31,7 +31,6 @@ from rag.utils import num_tokens_from_string, truncate
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def sigmoid(x):
    """Logistic function: map a score (or array of scores) into (0, 1).

    Vectorized via numpy, so *x* may be a scalar or an ndarray.
    """
    z = np.exp(-x)
    return 1.0 / (1.0 + z)
|
||||||
|
|
||||||
@ -90,7 +89,6 @@ class DefaultRerank(Base):
|
|||||||
self._dynamic_batch_size = 8
|
self._dynamic_batch_size = 8
|
||||||
self._min_batch_size = 1
|
self._min_batch_size = 1
|
||||||
|
|
||||||
|
|
||||||
def torch_empty_cache(self):
|
def torch_empty_cache(self):
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
@ -282,6 +280,7 @@ class LocalAIRerank(Base):
|
|||||||
|
|
||||||
return rank, token_count
|
return rank, token_count
|
||||||
|
|
||||||
|
|
||||||
class NvidiaRerank(Base):
|
class NvidiaRerank(Base):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
||||||
@ -513,6 +512,40 @@ class QWenRerank(Base):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
|
raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingfaceRerank(DefaultRerank):
    """Re-rank documents via a HuggingFace text-embeddings-inference (TEI)
    style ``/rerank`` HTTP endpoint instead of a locally loaded model.

    Subclasses DefaultRerank only to share its interface; all scoring is
    delegated to the remote service.
    """

    @staticmethod
    def post(query: str, texts: list, url="127.0.0.1"):
        """POST (query, texts) to ``http://<url>/rerank`` in batches of 8.

        Returns a np.ndarray of scores aligned with *texts*. Batches are
        best-effort: every batch is attempted, failed batches leave their
        scores at 0.0, and if any batch failed the last exception is
        re-raised after the loop.
        """
        exc = None
        scores = [0.0] * len(texts)  # floats; pre-filled so failed batches stay 0.0
        batch_size = 8
        for i in range(0, len(texts), batch_size):
            try:
                res = requests.post(
                    f"http://{url}/rerank",
                    headers={"Content-Type": "application/json"},
                    json={
                        "query": query,
                        "texts": texts[i: i + batch_size],
                        "raw_scores": False,
                        "truncate": True,
                    },
                )
                # TEI replies with [{"index": <pos-in-batch>, "score": <float>}, ...];
                # offset the in-batch index by i to place it in the full list.
                for o in res.json():
                    scores[o["index"] + i] = o["score"]
            except Exception as e:
                exc = e

        if exc:
            raise exc
        return np.array(scores)

    def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
        self.model_name = model_name
        # BUGFIX: post() unconditionally prepends "http://", so storing the
        # scheme here produced the malformed URL "http://http://...".
        # Strip any scheme (and a trailing slash) before storing; callers
        # that already pass a bare host:port are unaffected.
        base_url = base_url or "127.0.0.1"
        self.base_url = base_url.removeprefix("https://").removeprefix("http://").rstrip("/")

    def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
        """Score *texts* against *query* via the remote endpoint.

        Returns (scores, token_count); empty input short-circuits to an
        empty array with zero tokens, skipping any network call.
        """
        if not texts:
            return np.array([]), 0
        token_count = sum(num_tokens_from_string(t) for t in texts)
        return HuggingfaceRerank.post(query, texts, self.base_url), token_count
||||||
|
|
||||||
|
|
||||||
class GPUStackRerank(Base):
|
class GPUStackRerank(Base):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, key, model_name, base_url
|
self, key, model_name, base_url
|
||||||
@ -560,5 +593,6 @@ class GPUStackRerank(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
|
raise ValueError(
|
||||||
|
f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user