mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 21:26:03 +08:00
add rerank check when doing mutil-retrieval (#9998)
This commit is contained in:
parent
5ad5d0cff4
commit
9ebd453b87
@ -1,6 +1,6 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class RerankMode(Enum):
|
class RerankMode(str, Enum):
|
||||||
RERANKING_MODEL = "reranking_model"
|
RERANKING_MODEL = "reranking_model"
|
||||||
WEIGHTED_SCORE = "weighted_score"
|
WEIGHTED_SCORE = "weighted_score"
|
||||||
|
@ -22,6 +22,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaK
|
|||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.entities.context_entities import DocumentContext
|
from core.rag.entities.context_entities import DocumentContext
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
|
from core.rag.rerank.rerank_type import RerankMode
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||||
@ -361,10 +362,39 @@ class DatasetRetrieval:
|
|||||||
reranking_enable: bool = True,
|
reranking_enable: bool = True,
|
||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
if not available_datasets:
|
||||||
|
return []
|
||||||
threads = []
|
threads = []
|
||||||
all_documents = []
|
all_documents = []
|
||||||
dataset_ids = [dataset.id for dataset in available_datasets]
|
dataset_ids = [dataset.id for dataset in available_datasets]
|
||||||
index_type = None
|
index_type_check = all(
|
||||||
|
item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
|
||||||
|
)
|
||||||
|
if not index_type_check and (not reranking_enable or reranking_mode != RerankMode.RERANKING_MODEL):
|
||||||
|
raise ValueError(
|
||||||
|
"The configured knowledge base list have different indexing technique, please set reranking model."
|
||||||
|
)
|
||||||
|
index_type = available_datasets[0].indexing_technique
|
||||||
|
if index_type == "high_quality":
|
||||||
|
embedding_model_check = all(
|
||||||
|
item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
|
||||||
|
)
|
||||||
|
embedding_model_provider_check = all(
|
||||||
|
item.embedding_model_provider == available_datasets[0].embedding_model_provider
|
||||||
|
for item in available_datasets
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
reranking_enable
|
||||||
|
and reranking_mode == "weighted_score"
|
||||||
|
and (not embedding_model_check or not embedding_model_provider_check)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"The configured knowledge base list have different embedding model, please set reranking model."
|
||||||
|
)
|
||||||
|
if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
|
||||||
|
weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
|
||||||
|
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||||
|
|
||||||
for dataset in available_datasets:
|
for dataset in available_datasets:
|
||||||
index_type = dataset.indexing_technique
|
index_type = dataset.indexing_technique
|
||||||
retrieval_thread = threading.Thread(
|
retrieval_thread = threading.Thread(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user