From cc4785f094d122dc7b2be67471fc27886eaede38 Mon Sep 17 00:00:00 2001 From: Weaxs <459312872@qq.com> Date: Thu, 1 Aug 2024 19:57:53 +0800 Subject: [PATCH] fix: xinference reranker return_documents (#6888) --- .../model_runtime/model_providers/xinference/rerank/rerank.py | 3 ++- api/tests/integration_tests/model_runtime/__mock/xinference.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 649898f47a..b361806bcd 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -57,6 +57,7 @@ class XinferenceRerankModel(RerankModel): documents=docs, query=query, top_n=top_n, + return_documents=True ) except RuntimeError as e: raise InvokeServerUnavailableError(str(e)) @@ -66,7 +67,7 @@ class XinferenceRerankModel(RerankModel): for idx, result in enumerate(response['results']): # format document index = result['index'] - page_content = result['document'] + page_content = result['document'] if isinstance(result['document'], str) else result['document']['text'] rerank_document = RerankDocument( index=index, text=page_content, diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index ddb18fe919..7cb0a1318e 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -106,7 +106,7 @@ class MockXinferenceClass: def _check_cluster_authenticated(self): self._cluster_authed = True - def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int) -> dict: + def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict: # check if self._model_uid is a valid uuid if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ self._model_uid != 'rerank':