mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 03:19:01 +08:00
Hotfix/fix documents index mismatch error in rerank (#1662)
Co-authored-by: baomi.wbm <baomi.wbm@dtwave-inc.com>
This commit is contained in:
parent
0423775687
commit
22bc9ddc73
@ -1,14 +1,15 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional, List
|
from typing import List, Optional
|
||||||
|
|
||||||
import cohere
|
import cohere
|
||||||
import openai
|
import openai
|
||||||
from langchain.schema import Document
|
from core.model_providers.error import (LLMAPIConnectionError,
|
||||||
|
LLMAPIUnavailableError,
|
||||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
LLMAuthorizationError,
|
||||||
LLMRateLimitError, LLMAuthorizationError
|
LLMBadRequestError, LLMRateLimitError)
|
||||||
from core.model_providers.models.reranking.base import BaseReranking
|
from core.model_providers.models.reranking.base import BaseReranking
|
||||||
from core.model_providers.providers.base import BaseModelProvider
|
from core.model_providers.providers.base import BaseModelProvider
|
||||||
|
from langchain.schema import Document
|
||||||
|
|
||||||
|
|
||||||
class CohereReranking(BaseReranking):
|
class CohereReranking(BaseReranking):
|
||||||
@ -26,10 +27,14 @@ class CohereReranking(BaseReranking):
|
|||||||
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
|
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
|
||||||
docs = []
|
docs = []
|
||||||
doc_id = []
|
doc_id = []
|
||||||
|
unique_documents = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
if document.metadata['doc_id'] not in doc_id:
|
if document.metadata['doc_id'] not in doc_id:
|
||||||
doc_id.append(document.metadata['doc_id'])
|
doc_id.append(document.metadata['doc_id'])
|
||||||
docs.append(document.page_content)
|
docs.append(document.page_content)
|
||||||
|
unique_documents.append(document)
|
||||||
|
documents = unique_documents
|
||||||
|
|
||||||
results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
|
results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
|
||||||
rerank_documents = []
|
rerank_documents = []
|
||||||
|
|
||||||
|
@ -23,11 +23,14 @@ class XinferenceReranking(BaseReranking):
|
|||||||
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
|
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
|
||||||
docs = []
|
docs = []
|
||||||
doc_id = []
|
doc_id = []
|
||||||
|
unique_documents = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
if document.metadata['doc_id'] not in doc_id:
|
if document.metadata['doc_id'] not in doc_id:
|
||||||
doc_id.append(document.metadata['doc_id'])
|
doc_id.append(document.metadata['doc_id'])
|
||||||
docs.append(document.page_content)
|
docs.append(document.page_content)
|
||||||
|
unique_documents.append(document)
|
||||||
|
documents = unique_documents
|
||||||
|
|
||||||
model = self.client.get_model(self.credentials['model_uid'])
|
model = self.client.get_model(self.credentials['model_uid'])
|
||||||
response = model.rerank(query=query, documents=docs, top_n=top_k)
|
response = model.rerank(query=query, documents=docs, top_n=top_k)
|
||||||
rerank_documents = []
|
rerank_documents = []
|
||||||
|
Loading…
x
Reference in New Issue
Block a user