mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 20:39:01 +08:00
Hotfix/fix documents index mismatch error in rerank (#1662)
Co-authored-by: baomi.wbm <baomi.wbm@dtwave-inc.com>
This commit is contained in:
parent
0423775687
commit
22bc9ddc73
@ -1,14 +1,15 @@
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from typing import List, Optional
|
||||
|
||||
import cohere
|
||||
import openai
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
LLMRateLimitError, LLMAuthorizationError
|
||||
from core.model_providers.error import (LLMAPIConnectionError,
|
||||
LLMAPIUnavailableError,
|
||||
LLMAuthorizationError,
|
||||
LLMBadRequestError, LLMRateLimitError)
|
||||
from core.model_providers.models.reranking.base import BaseReranking
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
class CohereReranking(BaseReranking):
|
||||
@ -26,10 +27,14 @@ class CohereReranking(BaseReranking):
|
||||
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
|
||||
docs = []
|
||||
doc_id = []
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
if document.metadata['doc_id'] not in doc_id:
|
||||
doc_id.append(document.metadata['doc_id'])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
documents = unique_documents
|
||||
|
||||
results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
|
||||
rerank_documents = []
|
||||
|
||||
|
@ -23,11 +23,14 @@ class XinferenceReranking(BaseReranking):
|
||||
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
|
||||
docs = []
|
||||
doc_id = []
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
if document.metadata['doc_id'] not in doc_id:
|
||||
doc_id.append(document.metadata['doc_id'])
|
||||
docs.append(document.page_content)
|
||||
|
||||
unique_documents.append(document)
|
||||
documents = unique_documents
|
||||
|
||||
model = self.client.get_model(self.credentials['model_uid'])
|
||||
response = model.rerank(query=query, documents=docs, top_n=top_k)
|
||||
rerank_documents = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user