fix: xinference reranker return_documents (#6888)

This commit is contained in:
Weaxs 2024-08-01 19:57:53 +08:00 committed by GitHub
parent 093f902335
commit cc4785f094
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 2 deletions

View File

@ -57,6 +57,7 @@ class XinferenceRerankModel(RerankModel):
documents=docs, documents=docs,
query=query, query=query,
top_n=top_n, top_n=top_n,
return_documents=True
) )
except RuntimeError as e: except RuntimeError as e:
raise InvokeServerUnavailableError(str(e)) raise InvokeServerUnavailableError(str(e))
@ -66,7 +67,7 @@ class XinferenceRerankModel(RerankModel):
for idx, result in enumerate(response['results']): for idx, result in enumerate(response['results']):
# format document # format document
index = result['index'] index = result['index']
page_content = result['document'] page_content = result['document'] if isinstance(result['document'], str) else result['document']['text']
rerank_document = RerankDocument( rerank_document = RerankDocument(
index=index, index=index,
text=page_content, text=page_content,

View File

@ -106,7 +106,7 @@ class MockXinferenceClass:
def _check_cluster_authenticated(self): def _check_cluster_authenticated(self):
self._cluster_authed = True self._cluster_authed = True
def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int) -> dict: def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict:
# check if self._model_uid is a valid uuid # check if self._model_uid is a valid uuid
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
self._model_uid != 'rerank': self._model_uid != 'rerank':