diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 5a3c6f8432..c11beaeee1 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -257,7 +257,8 @@ class DatasetDocumentListApi(Resource): parser.add_argument("original_document_id", type=str, required=False, location="json") parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") - + parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") + parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") parser.add_argument( "doc_language", type=str, default="English", required=False, nullable=False, location="json" ) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 3be8a38b03..82433b64ff 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -792,13 +792,19 @@ class DocumentService: dataset.indexing_technique = knowledge_config.indexing_technique if knowledge_config.indexing_technique == "high_quality": model_manager = ModelManager() - embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) - dataset.embedding_model = embedding_model.model - dataset.embedding_model_provider = embedding_model.provider + if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: + dataset_embedding_model = knowledge_config.embedding_model + dataset_embedding_model_provider = knowledge_config.embedding_model_provider + else: + embedding_model = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING + ) + dataset_embedding_model = embedding_model.model + dataset_embedding_model_provider = embedding_model.provider + dataset.embedding_model = dataset_embedding_model + dataset.embedding_model_provider = dataset_embedding_model_provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model + dataset_embedding_model_provider, dataset_embedding_model ) dataset.collection_binding_id = dataset_collection_binding.id if not dataset.retrieval_model: @@ -810,7 +816,11 @@ class DocumentService: "score_threshold_enabled": False, } - dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore + dataset.retrieval_model = ( + knowledge_config.retrieval_model.model_dump() + if knowledge_config.retrieval_model + else default_retrieval_model + ) # type: ignore documents = [] if knowledge_config.original_document_id: diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 0efc924a77..a9b5ab91a8 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -28,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): if not dataset: raise Exception("Dataset not found") - index_type = dataset.doc_form + index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "remove": index_processor.clean(dataset, None, with_keywords=False) @@ -157,6 +157,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): {"indexing_status": "error", "error": str(e)}, synchronize_session=False ) db.session.commit() + else: + # clean collection + index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) end_at = time.perf_counter() logging.info(