diff --git a/api/core/rag/rerank/rerank_type.py b/api/core/rag/rerank/rerank_type.py index d4894e3cc6..d71eb2daa8 100644 --- a/api/core/rag/rerank/rerank_type.py +++ b/api/core/rag/rerank/rerank_type.py @@ -1,6 +1,6 @@ from enum import Enum -class RerankMode(Enum): +class RerankMode(str, Enum): RERANKING_MODEL = "reranking_model" WEIGHTED_SCORE = "weighted_score" diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 3455cdc3c4..7a5bf39fa6 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -22,6 +22,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaK from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.context_entities import DocumentContext from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter @@ -361,10 +362,39 @@ class DatasetRetrieval: reranking_enable: bool = True, message_id: Optional[str] = None, ): + if not available_datasets: + return [] threads = [] all_documents = [] dataset_ids = [dataset.id for dataset in available_datasets] - index_type = None + index_type_check = all( + item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets + ) + if not index_type_check and (not reranking_enable or reranking_mode != RerankMode.RERANKING_MODEL): + raise ValueError( + "The configured knowledge base list have different indexing technique, please set reranking model." + ) + index_type = available_datasets[0].indexing_technique + if index_type == "high_quality": + embedding_model_check = all( + item.embedding_model == available_datasets[0].embedding_model for item in available_datasets + ) + embedding_model_provider_check = all( + item.embedding_model_provider == available_datasets[0].embedding_model_provider + for item in available_datasets + ) + if ( + reranking_enable + and reranking_mode == "weighted_score" + and (not embedding_model_check or not embedding_model_provider_check) + ): + raise ValueError( + "The configured knowledge base list have different embedding model, please set reranking model." + ) + if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE: + weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider + weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model + for dataset in available_datasets: index_type = dataset.indexing_technique retrieval_thread = threading.Thread(