diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 99d3b73d33..b348f7a796 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -469,6 +469,7 @@ class DefaultRagPipelineBlockConfigApi(Resource): rag_pipeline_service = RagPipelineService() return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters) + class RagPipelineConfigApi(Resource): """Resource for rag pipeline configuration.""" diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 6fc23e88cc..aa31a7f86a 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -4,7 +4,7 @@ from collections.abc import Mapping from enum import Enum from typing import Any, Optional, Union -from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator +from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator, model_validator from core.datasource.entities.constants import DATASOURCE_SELECTOR_MODEL_IDENTITY from core.entities.provider_entities import ProviderConfig diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 85d4d130ba..260d4f12db 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -192,10 +192,12 @@ class ToolProviderID(GenericProviderID): if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]: self.plugin_name = f"{self.provider_name}_tool" + class DatasourceProviderID(GenericProviderID): def __init__(self, value: str, is_hardcoded: bool = False) -> None: super().__init__(value, is_hardcoded) + class PluginDependency(BaseModel): class Type(enum.StrEnum): Github = PluginInstallationSource.Github.value diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 5f9ac78097..b8901e5cce 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,7 +1,13 @@ +import datetime import logging import time from typing import Any, cast +from flask_login import current_user + +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables.segments import ObjectSegment from core.workflow.entities.node_entities import NodeRunResult @@ -11,7 +17,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, Document, RateLimitLog from models.workflow import WorkflowNodeExecutionStatus -from services.dataset_service import DocumentService +from services.dataset_service import DatasetCollectionBindingService from services.feature_service import FeatureService from .entities import KnowledgeIndexNodeData @@ -109,14 +115,52 @@ class KnowledgeIndexNode(LLMNode): if not document: raise KnowledgeIndexNodeError(f"Document {node_data.document_id} not found.") - DocumentService.invoke_knowledge_index( - dataset=dataset, - document=document, - chunks=chunks, - chunk_structure=node_data.chunk_structure, - index_method=node_data.index_method, - retrieval_setting=node_data.retrieval_setting, - ) + retrieval_setting = node_data.retrieval_setting + index_method = node_data.index_method + if not dataset.indexing_technique: + if node_data.index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + raise ValueError("Indexing technique is invalid") + + dataset.indexing_technique = index_method.indexing_technique + if index_method.indexing_technique == "high_quality": + model_manager = ModelManager() + if ( + index_method.embedding_setting.embedding_model + and index_method.embedding_setting.embedding_model_provider + ): + dataset_embedding_model = index_method.embedding_setting.embedding_model + dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider + else: + embedding_model = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING + ) + dataset_embedding_model = embedding_model.model + dataset_embedding_model_provider = embedding_model.provider + dataset.embedding_model = dataset_embedding_model + dataset.embedding_model_provider = dataset_embedding_model_provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + dataset_embedding_model_provider, dataset_embedding_model + ) + dataset.collection_binding_id = dataset_collection_binding.id + if not dataset.retrieval_model: + default_retrieval_model = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, + } + + dataset.retrieval_model = ( + retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model + ) # type: ignore + index_processor = IndexProcessorFactory(node_data.chunk_structure).init_index_processor() + index_processor.index(dataset, document, chunks) + + # update document status + document.indexing_status = "completed" + document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.commit() return { "dataset_id": dataset.id, diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 02954cdb44..df31e0f7ca 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,7 +6,7 @@ import random import time import uuid from collections import Counter -from typing import Any, Literal, Optional +from typing import Any, Optional from flask_login import current_user from sqlalchemy import func, select @@ -20,9 +20,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.plugin.entities.plugin import ModelProviderID from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexType -from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.nodes.knowledge_index.entities import IndexMethod, RetrievalSetting from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db @@ -1516,60 +1514,6 @@ class DocumentService: return documents, batch - @staticmethod - def invoke_knowledge_index( - dataset: Dataset, - document: Document, - chunks: list[Any], - index_method: IndexMethod, - retrieval_setting: RetrievalSetting, - chunk_structure: Literal["text_model", "hierarchical_model"], - ): - if not dataset.indexing_technique: - if index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: - raise ValueError("Indexing technique is invalid") - - dataset.indexing_technique = index_method.indexing_technique - if index_method.indexing_technique == "high_quality": - model_manager = ModelManager() - if ( - index_method.embedding_setting.embedding_model - and index_method.embedding_setting.embedding_model_provider - ): - dataset_embedding_model = index_method.embedding_setting.embedding_model - dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider - else: - embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) - dataset_embedding_model = embedding_model.model - dataset_embedding_model_provider = embedding_model.provider - dataset.embedding_model = dataset_embedding_model - dataset.embedding_model_provider = dataset_embedding_model_provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - dataset_embedding_model_provider, dataset_embedding_model - ) - dataset.collection_binding_id = dataset_collection_binding.id - if not dataset.retrieval_model: - default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, - "score_threshold_enabled": False, - } - - dataset.retrieval_model = ( - retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model - ) # type: ignore - index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() - index_processor.index(dataset, document, chunks) - - # update document status - document.indexing_status = "completed" - document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.commit() - @staticmethod def check_documents_upload_quota(count: int, features: FeatureModel): can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size