mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 07:05:52 +08:00
compatible xinference reranker server (#6927)
This commit is contained in:
parent
26e46d365c
commit
5e634a59a2
@ -51,17 +51,22 @@ class XinferenceRerankModel(RerankModel):
|
|||||||
server_url = server_url[:-1]
|
server_url = server_url[:-1]
|
||||||
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'documents': docs,
|
||||||
|
'query': query,
|
||||||
|
'top_n': top_n,
|
||||||
|
'return_documents': True
|
||||||
|
}
|
||||||
try:
|
try:
|
||||||
handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
|
handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
|
||||||
response = handle.rerank(
|
response = handle.rerank(**params)
|
||||||
documents=docs,
|
|
||||||
query=query,
|
|
||||||
top_n=top_n,
|
|
||||||
return_documents=True
|
|
||||||
)
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise InvokeServerUnavailableError(str(e))
|
if "rerank hasn't support extra parameter" not in str(e):
|
||||||
|
raise InvokeServerUnavailableError(str(e))
|
||||||
|
|
||||||
|
# Compatibility with xinference servers v0.10.1 - v0.12.1, which do not support 'return_len'
|
||||||
|
handle = RESTfulRerankModelHandleWithoutExtraParameter(model_uid, server_url, auth_headers)
|
||||||
|
response = handle.rerank(**params)
|
||||||
|
|
||||||
rerank_documents = []
|
rerank_documents = []
|
||||||
for idx, result in enumerate(response['results']):
|
for idx, result in enumerate(response['results']):
|
||||||
@ -167,8 +172,40 @@ class XinferenceRerankModel(RerankModel):
|
|||||||
),
|
),
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
model_type=ModelType.RERANK,
|
model_type=ModelType.RERANK,
|
||||||
model_properties={ },
|
model_properties={},
|
||||||
parameter_rules=[]
|
parameter_rules=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
return entity
|
return entity
|
||||||
|
|
||||||
|
|
||||||
|
class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle):
    """Rerank handle that sends only the basic rerank fields to the server.

    The request body built here contains just ``model``, ``documents``,
    ``query``, ``top_n``, ``max_chunks_per_doc`` and ``return_documents``;
    any additional keyword arguments are accepted but never forwarded.
    """

    def rerank(
        self,
        documents: list[str],
        query: str,
        top_n: Optional[int] = None,
        max_chunks_per_doc: Optional[int] = None,
        return_documents: Optional[bool] = None,
        **kwargs
    ):
        """POST a rerank request to ``{base_url}/v1/rerank`` and return the JSON reply.

        :param documents: candidate documents to rerank
        :param query: query the documents are ranked against
        :param top_n: value sent as ``top_n`` in the request body
        :param max_chunks_per_doc: value sent as ``max_chunks_per_doc``
        :param return_documents: value sent as ``return_documents``
        :param kwargs: accepted for signature compatibility and ignored
        :return: the decoded JSON body of the server response
        :raises InvokeServerUnavailableError: if the server answers with a
            status code other than 200
        """
        # Kept as a function-scope import, exactly as in the original block.
        import requests

        url = f"{self._base_url}/v1/rerank"
        request_body = {
            "model": self._model_uid,
            "documents": documents,
            "query": query,
            "top_n": top_n,
            "max_chunks_per_doc": max_chunks_per_doc,
            "return_documents": return_documents,
        }

        response = requests.post(url, json=request_body, headers=self.auth_headers)
        if response.status_code != 200:
            # The error body is not guaranteed to be JSON with a 'detail' key
            # (e.g. an HTML page from a gateway). Fall back to the raw text so
            # the caller still gets InvokeServerUnavailableError rather than a
            # JSON decoding / lookup error.
            try:
                detail = response.json()['detail']
            except (ValueError, KeyError, TypeError):
                detail = response.text
            raise InvokeServerUnavailableError(
                f"Failed to rerank documents, detail: {detail}"
            )

        return response.json()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user