diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py
index b361806bcd..4e7543fd99 100644
--- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py
+++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py
@@ -51,17 +51,23 @@ class XinferenceRerankModel(RerankModel):
             server_url = server_url[:-1]
         auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
 
+        params = {
+            'documents': docs,
+            'query': query,
+            'top_n': top_n,
+            'return_documents': True
+        }
 
         try:
             handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
-            response = handle.rerank(
-                documents=docs,
-                query=query,
-                top_n=top_n,
-                return_documents=True
-            )
+            response = handle.rerank(**params)
         except RuntimeError as e:
-            raise InvokeServerUnavailableError(str(e))
+            if "rerank hasn't support extra parameter" not in str(e):
+                raise InvokeServerUnavailableError(str(e)) from e
+            # Compatibility with xinference servers v0.10.1 - v0.12.1, which do not support
+            # 'return_len': retry with a handle that sends no extra parameters.
+            handle = RESTfulRerankModelHandleWithoutExtraParameter(model_uid, server_url, auth_headers)
+            response = handle.rerank(**params)
 
         rerank_documents = []
         for idx, result in enumerate(response['results']):
@@ -167,8 +173,55 @@ class XinferenceRerankModel(RerankModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.RERANK,
-            model_properties={ },
+            model_properties={},
             parameter_rules=[]
         )
 
         return entity
+
+
+class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle):
+    """Rerank handle that posts only the basic rerank fields to ``/v1/rerank``.
+
+    Used as a fallback for xinference servers that reject extra parameters.
+    """
+
+    def rerank(
+        self,
+        documents: list[str],
+        query: str,
+        top_n: Optional[int] = None,
+        max_chunks_per_doc: Optional[int] = None,
+        return_documents: Optional[bool] = None,
+        **kwargs
+    ):
+        """Rerank ``documents`` against ``query`` and return the decoded JSON response.
+
+        Any ``kwargs`` are accepted but ignored and never sent to the server.
+
+        :raises InvokeServerUnavailableError: if the server does not answer with HTTP 200
+        """
+        import requests
+
+        url = f"{self._base_url}/v1/rerank"
+        request_body = {
+            "model": self._model_uid,
+            "documents": documents,
+            "query": query,
+            "top_n": top_n,
+            "max_chunks_per_doc": max_chunks_per_doc,
+            "return_documents": return_documents,
+        }
+
+        response = requests.post(url, json=request_body, headers=self.auth_headers)
+        if response.status_code != 200:
+            # The error body is not guaranteed to be JSON with a 'detail' key (e.g. a
+            # proxy error page), so fall back to the raw text instead of crashing here.
+            try:
+                detail = response.json()['detail']
+            except (ValueError, KeyError, TypeError):
+                detail = response.text
+            raise InvokeServerUnavailableError(
+                f"Failed to rerank documents, detail: {detail}"
+            )
+        return response.json()