compatible xinference reranker server (#6927)

This commit is contained in:
Weaxs 2024-08-04 13:49:38 +08:00 committed by GitHub
parent 26e46d365c
commit 5e634a59a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -51,17 +51,22 @@ class XinferenceRerankModel(RerankModel):
server_url = server_url[:-1] server_url = server_url[:-1]
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
params = {
'documents': docs,
'query': query,
'top_n': top_n,
'return_documents': True
}
try: try:
handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers) handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
response = handle.rerank( response = handle.rerank(**params)
documents=docs,
query=query,
top_n=top_n,
return_documents=True
)
except RuntimeError as e: except RuntimeError as e:
raise InvokeServerUnavailableError(str(e)) if "rerank hasn't support extra parameter" not in str(e):
raise InvokeServerUnavailableError(str(e))
# compatible xinference server between v0.10.1 - v0.12.1, not support 'return_len'
handle = RESTfulRerankModelHandleWithoutExtraParameter(model_uid, server_url, auth_headers)
response = handle.rerank(**params)
rerank_documents = [] rerank_documents = []
for idx, result in enumerate(response['results']): for idx, result in enumerate(response['results']):
@ -167,8 +172,40 @@ class XinferenceRerankModel(RerankModel):
), ),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.RERANK, model_type=ModelType.RERANK,
model_properties={ }, model_properties={},
parameter_rules=[] parameter_rules=[]
) )
return entity return entity
class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle):
def rerank(
self,
documents: list[str],
query: str,
top_n: Optional[int] = None,
max_chunks_per_doc: Optional[int] = None,
return_documents: Optional[bool] = None,
**kwargs
):
url = f"{self._base_url}/v1/rerank"
request_body = {
"model": self._model_uid,
"documents": documents,
"query": query,
"top_n": top_n,
"max_chunks_per_doc": max_chunks_per_doc,
"return_documents": return_documents,
}
import requests
response = requests.post(url, json=request_body, headers=self.auth_headers)
if response.status_code != 200:
raise InvokeServerUnavailableError(
f"Failed to rerank documents, detail: {response.json()['detail']}"
)
response_data = response.json()
return response_data