diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 6bc29a8643..076f3cd44d 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -302,6 +302,8 @@ class DatasetInitApi(Resource): "doc_language", type=str, default="English", required=False, nullable=False, location="json" ) parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") + parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") + parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator @@ -309,6 +311,8 @@ class DatasetInitApi(Resource): raise Forbidden() if args["indexing_technique"] == "high_quality": + if args["embedding_model"] is None or args["embedding_model_provider"] is None: + raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: model_manager = ModelManager() model_manager.get_default_model_instance( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 4a11de281c..cce0874cf4 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1057,12 +1057,8 @@ class DocumentService: dataset_collection_binding_id = None retrieval_model = None if document_data["indexing_technique"] == "high_quality": - model_manager = ModelManager() - embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model + document_data["embedding_model_provider"], document_data["embedding_model"] ) dataset_collection_binding_id = dataset_collection_binding.id if document_data.get("retrieval_model"):