From 22bc9ddc73b3f0d04fe72f0d8a1d94f4181b2804 Mon Sep 17 00:00:00 2001 From: WangBooth Date: Thu, 30 Nov 2023 22:03:20 +0800 Subject: [PATCH] Hotfix/fix documents index mismatch error in rerank (#1662) Co-authored-by: baomi.wbm --- .../models/reranking/cohere_reranking.py | 15 ++++++++++----- .../models/reranking/xinference_reranking.py | 5 ++++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/api/core/model_providers/models/reranking/cohere_reranking.py b/api/core/model_providers/models/reranking/cohere_reranking.py index 3119caeae1..e5b2fe61d1 100644 --- a/api/core/model_providers/models/reranking/cohere_reranking.py +++ b/api/core/model_providers/models/reranking/cohere_reranking.py @@ -1,14 +1,15 @@ import logging -from typing import Optional, List +from typing import List, Optional import cohere import openai -from langchain.schema import Document - -from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ - LLMRateLimitError, LLMAuthorizationError +from core.model_providers.error import (LLMAPIConnectionError, + LLMAPIUnavailableError, + LLMAuthorizationError, + LLMBadRequestError, LLMRateLimitError) from core.model_providers.models.reranking.base import BaseReranking from core.model_providers.providers.base import BaseModelProvider +from langchain.schema import Document class CohereReranking(BaseReranking): @@ -26,10 +27,14 @@ class CohereReranking(BaseReranking): def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]: docs = [] doc_id = [] + unique_documents = [] for document in documents: if document.metadata['doc_id'] not in doc_id: doc_id.append(document.metadata['doc_id']) docs.append(document.page_content) + unique_documents.append(document) + documents = unique_documents + results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k) rerank_documents = [] diff --git a/api/core/model_providers/models/reranking/xinference_reranking.py b/api/core/model_providers/models/reranking/xinference_reranking.py index 47c1c6cd01..0ae9eaec6e 100644 --- a/api/core/model_providers/models/reranking/xinference_reranking.py +++ b/api/core/model_providers/models/reranking/xinference_reranking.py @@ -23,11 +23,14 @@ class XinferenceReranking(BaseReranking): def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]: docs = [] doc_id = [] + unique_documents = [] for document in documents: if document.metadata['doc_id'] not in doc_id: doc_id.append(document.metadata['doc_id']) docs.append(document.page_content) - + unique_documents.append(document) + documents = unique_documents + model = self.client.get_model(self.credentials['model_uid']) response = model.rerank(query=query, documents=docs, top_n=top_k) rerank_documents = []