diff --git a/api/core/model_providers/models/reranking/cohere_reranking.py b/api/core/model_providers/models/reranking/cohere_reranking.py index e5b2fe61d1..fa2748734d 100644 --- a/api/core/model_providers/models/reranking/cohere_reranking.py +++ b/api/core/model_providers/models/reranking/cohere_reranking.py @@ -24,7 +24,10 @@ class CohereReranking(BaseReranking): super().__init__(model_provider, client, name) - def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]: + def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> \ + Optional[List[Document]]: + if not documents: + return [] docs = [] doc_id = [] unique_documents = [] @@ -34,7 +37,7 @@ class CohereReranking(BaseReranking): docs.append(document.page_content) unique_documents.append(document) documents = unique_documents - + results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k) rerank_documents = [] diff --git a/api/core/model_providers/models/reranking/xinference_reranking.py b/api/core/model_providers/models/reranking/xinference_reranking.py index 0ae9eaec6e..fda2772c70 100644 --- a/api/core/model_providers/models/reranking/xinference_reranking.py +++ b/api/core/model_providers/models/reranking/xinference_reranking.py @@ -21,6 +21,8 @@ class XinferenceReranking(BaseReranking): super().__init__(model_provider, client, name) def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]: + if not documents: + return [] docs = [] doc_id = [] unique_documents = []