From 1c179b17bccb76a9c92a32d0ca9028c8de139322 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Thu, 15 May 2025 17:19:14 +0800
Subject: [PATCH] r2

---
 api/controllers/console/__init__.py           |  7 +++++-
 .../processor/paragraph_index_processor.py    |  3 +--
 .../processor/parent_child_index_processor.py |  3 +--
 api/core/rag/models/document.py               | 25 +++++++++++++++++++
 .../nodes/knowledge_index/__init__.py         |  4 +--
 .../knowledge_index/knowledge_index_node.py   |  4 +--
 .../nodes/knowledge_retrieval/entities.py     |  3 +--
 api/fields/workflow_fields.py                 |  4 ---
 .../database/database_retrieval.py            | 12 ++++-----
 api/services/rag_pipeline/rag_pipeline.py     | 14 ++++-------
 .../rag_pipeline/rag_pipeline_dsl_service.py  | 24 ++++++++++++++++--
 11 files changed, 71 insertions(+), 32 deletions(-)

diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 998ec2e3bf..c55d3fbb66 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -3,7 +3,6 @@ from flask import Blueprint
 from libs.external_api import ExternalApi
 
 from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
-from .datasets.rag_pipeline import data_source
 from .explore.audio import ChatAudioApi, ChatTextApi
 from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
 from .explore.conversation import (
@@ -84,6 +83,12 @@ from .datasets import (
     metadata,
     website,
 )
+from .datasets.rag_pipeline import (
+    rag_pipeline,
+    rag_pipeline_datasets,
+    rag_pipeline_import,
+    rag_pipeline_workflow,
+)
 
 # Import explore controllers
 from .explore import (
diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py
index 43d201af73..155aae61d4 100644
--- a/api/core/rag/index_processor/processor/paragraph_index_processor.py
+++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py
@@ -12,9 +12,8 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
-from core.rag.models.document import Document
+from core.rag.models.document import Document, GeneralStructureChunk
 from core.tools.utils.text_processing_utils import remove_leading_symbols
-from core.workflow.nodes.knowledge_index.entities import GeneralStructureChunk
 from libs import helper
 from models.dataset import Dataset, DatasetProcessRule
 from services.entities.knowledge_entities.knowledge_entities import Rule
diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py
index ce64bb2a54..5279864441 100644
--- a/api/core/rag/index_processor/processor/parent_child_index_processor.py
+++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py
@@ -13,8 +13,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
-from core.rag.models.document import ChildDocument, Document
-from core.workflow.nodes.knowledge_index.entities import ParentChildStructureChunk
+from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import ChildChunk, Dataset, DocumentSegment
diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py
index 421cdc05df..52795bbadf 100644
--- a/api/core/rag/models/document.py
+++ b/api/core/rag/models/document.py
@@ -35,6 +35,31 @@ class Document(BaseModel):
     children: Optional[list[ChildDocument]] = None
 
 
+class GeneralStructureChunk(BaseModel):
+    """
+    General Structure Chunk.
+    """
+
+    general_chunk: list[str]
+
+
+class ParentChildChunk(BaseModel):
+    """
+    Parent Child Chunk.
+    """
+
+    parent_content: str
+    child_contents: list[str]
+
+
+class ParentChildStructureChunk(BaseModel):
+    """
+    Parent Child Structure Chunk.
+    """
+
+    parent_child_chunks: list[ParentChildChunk]
+
+
 class BaseDocumentTransformer(ABC):
     """Abstract base class for document transformation systems.
 
diff --git a/api/core/workflow/nodes/knowledge_index/__init__.py b/api/core/workflow/nodes/knowledge_index/__init__.py
index 01d59b87b2..23897a1e42 100644
--- a/api/core/workflow/nodes/knowledge_index/__init__.py
+++ b/api/core/workflow/nodes/knowledge_index/__init__.py
@@ -1,3 +1,3 @@
-from .knowledge_index_node import KnowledgeRetrievalNode
+from .knowledge_index_node import KnowledgeIndexNode
 
-__all__ = ["KnowledgeRetrievalNode"]
+__all__ = ["KnowledgeIndexNode"]
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index b8901e5cce..f039b233a5 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -1,7 +1,7 @@
 import datetime
 import logging
 import time
-from typing import Any, cast
+from typing import Any, Mapping, cast
 
 from flask_login import current_user
 
@@ -106,7 +106,7 @@ class KnowledgeIndexNode(LLMNode):
                 error_type=type(e).__name__,
             )
 
-    def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: list[Any]) -> Any:
+    def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any]) -> Any:
         dataset = Dataset.query.filter_by(id=node_data.dataset_id).first()
         if not dataset:
             raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.")
diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py
index 17b3308a06..8c702b74ee 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/entities.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py
@@ -57,8 +57,7 @@ class MultipleRetrievalConfig(BaseModel):
 
 
 class ModelConfig(BaseModel):
-    """
-    Model Config.
+
     provider: str
     name: str
     mode: str
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index 45112d42f9..a37ae7856d 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -62,10 +62,6 @@ workflow_fields = {
     "tool_published": fields.Boolean,
     "environment_variables": fields.List(EnvironmentVariableField()),
     "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
-    "pipeline_variables": fields.Dict(
-        keys=fields.String,
-        values=fields.List(fields.Nested(pipeline_variable_fields)),
-    ),
 }
 
 workflow_partial_fields = {
diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
index 10dd044493..f6ab5c9064 100644
--- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
@@ -2,12 +2,12 @@ from typing import Optional
 
 from extensions.ext_database import db
 from models.dataset import Pipeline, PipelineBuiltInTemplate
-from services.app_dsl_service import AppDslService
-from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
-from services.recommend_app.recommend_app_type import RecommendAppType
+from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
+from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
+# from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
 
 
-class DatabasePipelineTemplateRetrieval(RecommendAppRetrievalBase):
+class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
     """
     Retrieval pipeline template from database
     """
@@ -21,7 +21,7 @@ class DatabasePipelineTemplateRetrieval(RecommendAppRetrievalBase):
         return result
 
     def get_type(self) -> str:
-        return RecommendAppType.DATABASE
+        return PipelineTemplateType.DATABASE
 
     @classmethod
     def fetch_pipeline_templates_from_db(cls, language: str) -> dict:
@@ -61,5 +61,5 @@ class DatabasePipelineTemplateRetrieval(RecommendAppRetrievalBase):
             "name": pipeline.name,
             "icon": pipeline.icon,
             "mode": pipeline.mode,
-            "export_data": AppDslService.export_dsl(app_model=pipeline),
+            "export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline),
         }
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 1e6447d80f..2275c32f63 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -13,8 +13,7 @@ from sqlalchemy.orm import Session
 import contexts
 from configs import dify_config
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.repository.repository_factory import RepositoryFactory
-from core.repository.workflow_node_execution_repository import OrderConfig
+from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
 from core.variables.variables import Variable
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.errors import WorkflowNodeRunFailedError
@@ -24,6 +23,7 @@ from core.workflow.nodes.enums import ErrorStrategy, NodeType
 from core.workflow.nodes.event.event import RunCompletedEvent
 from core.workflow.nodes.event.types import NodeEvent
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
+from core.workflow.repository.workflow_node_execution_repository import OrderConfig
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
@@ -650,13 +650,9 @@ class RagPipelineService:
         if not workflow_run:
             return []
 
-        # Use the repository to get the node executions
-        repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={
-                "tenant_id": pipeline.tenant_id,
-                "app_id": pipeline.id,
-                "session_factory": db.session.get_bind(),
-            }
+        # Use the repository to get the node execution
+        repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=db.engine, tenant_id=pipeline.tenant_id, app_id=pipeline.id
         )
 
         # Use the repository to get the node executions with ordering
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
index 80e7c6af0b..e50caa9756 100644
--- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -25,12 +25,12 @@ from core.workflow.nodes.llm.entities import LLMNodeData
 from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
 from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
 from core.workflow.nodes.tool.entities import ToolNodeData
+from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from factories import variable_factory
 from models import Account
-from models.dataset import Dataset, Pipeline
+from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
 from models.workflow import Workflow
-from services.dataset_service import DatasetCollectionBindingService
 from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration
 from services.plugin.dependencies_analysis import DependenciesAnalysisService
 from services.rag_pipeline.rag_pipeline import RagPipelineService
@@ -306,6 +306,26 @@ class RagPipelineDslService:
                     knowledge_configuration.index_method.embedding_setting.embedding_provider_name,  # type: ignore
                     knowledge_configuration.index_method.embedding_setting.embedding_model_name,  # type: ignore
                 )
+                dataset_collection_binding = (
+                    db.session.query(DatasetCollectionBinding)
+                    .filter(
+                        DatasetCollectionBinding.provider_name == knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
+                        DatasetCollectionBinding.model_name == knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                        DatasetCollectionBinding.type == "dataset",
+                    )
+                    .order_by(DatasetCollectionBinding.created_at)
+                    .first()
+                )
+
+                if not dataset_collection_binding:
+                    dataset_collection_binding = DatasetCollectionBinding(
+                        provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
+                        model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                        collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
+                        type="dataset",
+                    )
+                    db.session.add(dataset_collection_binding)
+                    db.session.commit()
                 dataset_collection_binding_id = dataset_collection_binding.id
                 dataset.collection_binding_id = dataset_collection_binding_id
                 dataset.embedding_model = (
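
For reference, a minimal usage sketch (not part of the patch) of the chunk models this change relocates into api/core/rag/models/document.py. The model definitions mirror the hunk above; the example values and the comments on field intent are assumptions drawn from how the two index processors appear to consume these types:

    from pydantic import BaseModel


    class GeneralStructureChunk(BaseModel):
        general_chunk: list[str]  # flat list of text chunks for the paragraph ("general") index


    class ParentChildChunk(BaseModel):
        parent_content: str        # parent passage returned as retrieval context
        child_contents: list[str]  # child segments that get embedded and matched


    class ParentChildStructureChunk(BaseModel):
        parent_child_chunks: list[ParentChildChunk]


    # Example payloads shaped the way the index processors would receive them.
    general = GeneralStructureChunk(general_chunk=["First paragraph.", "Second paragraph."])
    parent_child = ParentChildStructureChunk(
        parent_child_chunks=[
            ParentChildChunk(
                parent_content="First paragraph. Second paragraph.",
                child_contents=["First paragraph.", "Second paragraph."],
            )
        ]
    )

Moving these models out of core.workflow.nodes.knowledge_index.entities and into core.rag.models.document also removes an import from the RAG layer into the workflow layer, which is presumably the point of the relocation.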
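One caveat in database_retrieval.py: the import of RagPipelineDslService stays commented out (likely to avoid a circular import, since rag_pipeline_dsl_service itself imports from the rag_pipeline package), yet export_data now calls RagPipelineDslService.export_rag_pipeline_dsl, which would raise NameError when that code path runs. A sketch of the usual remedy, deferring the import to call time; the function name here is illustrative, not taken from the patch:

    def export_pipeline_template(pipeline) -> dict:
        # Deferred import: bind RagPipelineDslService only when the function is
        # called, so the two modules can reference each other without a circular
        # ImportError at startup.
        from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

        return {
            "name": pipeline.name,
            "icon": pipeline.icon,
            "mode": pipeline.mode,
            "export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline),
        }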
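Likewise, a short sketch of how the rewritten repository wiring in rag_pipeline.py is typically driven. OrderConfig comes from core.workflow.repository.workflow_node_execution_repository as imported above; the get_by_workflow_run call mirrors how Dify's other workflow services page through node executions, but treat its exact signature as an assumption rather than something this patch shows:

    from core.repositories.sqlalchemy_workflow_node_execution_repository import (
        SQLAlchemyWorkflowNodeExecutionRepository,
    )
    from core.workflow.repository.workflow_node_execution_repository import OrderConfig
    from extensions.ext_database import db

    # Given a Pipeline row `pipeline` and a WorkflowRun row `workflow_run`:
    repository = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=db.engine, tenant_id=pipeline.tenant_id, app_id=pipeline.id
    )
    # Fetch the node executions of one run, ordered by node index, newest first.
    order_config = OrderConfig(order_by=["index"], order_direction="desc")
    node_executions = repository.get_by_workflow_run(
        workflow_run_id=workflow_run.id, order_config=order_config
    )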