mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-14 07:35:55 +08:00
fix bugs of rerank model with xinference (#1481)
### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent
575099df2d
commit
99f7bbaaa2
@ -165,6 +165,17 @@ def add_llm():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg += f"\nFail to access model({llm['llm_name']})." + str(
|
msg += f"\nFail to access model({llm['llm_name']})." + str(
|
||||||
e)
|
e)
|
||||||
|
elif llm["model_type"] == LLMType.RERANK:
|
||||||
|
mdl = RerankModel[factory](
|
||||||
|
key=None, model_name=llm["llm_name"], base_url=llm["api_base"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!"])
|
||||||
|
if len(arr) == 0 or tc == 0:
|
||||||
|
raise Exception("Not known.")
|
||||||
|
except Exception as e:
|
||||||
|
msg += f"\nFail to access model({llm['llm_name']})." + str(
|
||||||
|
e)
|
||||||
else:
|
else:
|
||||||
# TODO: check other type of models
|
# TODO: check other type of models
|
||||||
pass
|
pass
|
||||||
|
@ -136,10 +136,11 @@ class YoudaoRerank(DefaultRerank):
|
|||||||
else: res.extend(scores)
|
else: res.extend(scores)
|
||||||
return np.array(res), token_count
|
return np.array(res), token_count
|
||||||
|
|
||||||
|
|
||||||
class XInferenceRerank(Base):
|
class XInferenceRerank(Base):
|
||||||
def __init__(self,model_name="",base_url=""):
|
def __init__(self, key="xxxxxxx", model_name="", base_url=""):
|
||||||
self.model_name=model_name
|
self.model_name = model_name
|
||||||
self.base_url=base_url
|
self.base_url = base_url
|
||||||
self.headers = {
|
self.headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"accept": "application/json"
|
"accept": "application/json"
|
||||||
@ -147,11 +148,12 @@ class XInferenceRerank(Base):
|
|||||||
|
|
||||||
def similarity(self, query: str, texts: list):
|
def similarity(self, query: str, texts: list):
|
||||||
data = {
|
data = {
|
||||||
"model":self.model_name,
|
"model": self.model_name,
|
||||||
"query":query,
|
"query": query,
|
||||||
"return_documents": "true",
|
"return_documents": "true",
|
||||||
"return_len": "true",
|
"return_len": "true",
|
||||||
"documents":texts
|
"documents": texts
|
||||||
}
|
}
|
||||||
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
||||||
return np.array([d["relevance_score"] for d in res["results"]]),res["tokens"]["input_tokens"]+res["tokens"]["output_tokens"]
|
return np.array([d["relevance_score"] for d in res["results"]]), res["tokens"]["input_tokens"] + res["tokens"][
|
||||||
|
"output_tokens"]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user