jyong 2025-05-15 17:19:14 +08:00
parent 3c90a7acee
commit 1c179b17bc
11 changed files with 71 additions and 32 deletions

View File

@@ -3,7 +3,6 @@ from flask import Blueprint
from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
from .datasets.rag_pipeline import data_source
from .explore.audio import ChatAudioApi, ChatTextApi
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
from .explore.conversation import (
@@ -84,6 +83,12 @@ from .datasets import (
metadata,
website,
)
from .datasets.rag_pipeline import (
rag_pipeline,
rag_pipeline_datasets,
rag_pipeline_import,
rag_pipeline_workflow,
)
# Import explore controllers
from .explore import (

View File

@@ -12,9 +12,8 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.rag.models.document import Document, GeneralStructureChunk
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.nodes.knowledge_index.entities import GeneralStructureChunk
from libs import helper
from models.dataset import Dataset, DatasetProcessRule
from services.entities.knowledge_entities.knowledge_entities import Rule

View File

@@ -13,8 +13,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document
from core.workflow.nodes.knowledge_index.entities import ParentChildStructureChunk
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, DocumentSegment

View File

@@ -35,6 +35,31 @@ class Document(BaseModel):
children: Optional[list[ChildDocument]] = None
class GeneralStructureChunk(BaseModel):
"""
General Structure Chunk.
"""
general_chunk: list[str]
class ParentChildChunk(BaseModel):
"""
Parent Child Chunk.
"""
parent_content: str
child_contents: list[str]
class ParentChildStructureChunk(BaseModel):
"""
Parent Child Structure Chunk.
"""
parent_child_chunks: list[ParentChildChunk]
class BaseDocumentTransformer(ABC):
"""Abstract base class for document transformation systems.

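The three Pydantic models added above define the chunk payloads that the knowledge-index flow passes around. A minimal sketch of how they would be instantiated, assuming nothing beyond the fields visible in this hunk (any extra validation or defaults in the real models are not shown):

    # Sketch only: field names are taken from the diff above.
    from core.rag.models.document import (
        GeneralStructureChunk,
        ParentChildChunk,
        ParentChildStructureChunk,
    )

    general = GeneralStructureChunk(
        general_chunk=["First chunk of text.", "Second chunk of text."]
    )

    parent_child = ParentChildStructureChunk(
        parent_child_chunks=[
            ParentChildChunk(
                parent_content="Parent paragraph kept for retrieval context.",
                child_contents=["Child sentence one.", "Child sentence two."],
            )
        ]
    )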
View File

@@ -1,3 +1,3 @@
from .knowledge_index_node import KnowledgeRetrievalNode
from .knowledge_index_node import KnowledgeIndexNode
__all__ = ["KnowledgeRetrievalNode"]
__all__ = ["KnowledgeIndexNode"]

View File

@@ -1,7 +1,7 @@
import datetime
import logging
import time
from typing import Any, cast
from typing import Any, cast, Mapping
from flask_login import current_user
@@ -106,7 +106,7 @@ class KnowledgeIndexNode(LLMNode):
error_type=type(e).__name__,
)
def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: list[Any]) -> Any:
def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any]) -> Any:
dataset = Dataset.query.filter_by(id=node_data.dataset_id).first()
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.")

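Tightening chunks from list[Any] to Mapping[str, Any] means callers now hand the node a read-only, dict-like payload. The hunk does not show which keys _invoke_knowledge_index actually reads, so the key in the sketch below is purely illustrative:

    # Illustrative call shape only; the real keys consumed by
    # _invoke_knowledge_index are not visible in this hunk.
    from typing import Any, Mapping

    chunks: Mapping[str, Any] = {
        "general_chunk": ["chunk one", "chunk two"],  # hypothetical key
    }
    # result = node._invoke_knowledge_index(node_data=node_data, chunks=chunks)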
View File

@@ -57,8 +57,7 @@ class MultipleRetrievalConfig(BaseModel):
class ModelConfig(BaseModel):
"""
Model Config.
"""
provider: str
name: str

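Assuming provider and name are the only required fields left after the removed line (the hunk may hide further fields below the shown context), instantiation would be a one-liner:

    # Assumes the two fields shown are the only required ones.
    config = ModelConfig(provider="openai", name="gpt-4o")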
View File

@@ -62,10 +62,6 @@ workflow_fields = {
"tool_published": fields.Boolean,
"environment_variables": fields.List(EnvironmentVariableField()),
"conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
"pipeline_variables": fields.Dict(
keys=fields.String,
values=fields.List(fields.Nested(pipeline_variable_fields)),
),
}
workflow_partial_fields = {

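workflow_fields is a flask-restful marshaling map, so dropping pipeline_variables simply removes that key from serialized workflow responses. Consumption is the standard decorator pattern, sketched below in the module where workflow_fields is defined; the resource class and helper are illustrative, not from this commit:

    # Standard flask-restful usage; WorkflowApi and get_workflow are
    # hypothetical names for illustration.
    from flask_restful import Resource, marshal_with

    class WorkflowApi(Resource):
        @marshal_with(workflow_fields)
        def get(self, workflow_id: str):
            # The returned object's attributes are filtered and serialized
            # according to workflow_fields.
            return get_workflow(workflow_id)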
View File

@@ -2,12 +2,12 @@ from typing import Optional
from extensions.ext_database import db
from models.dataset import Pipeline, PipelineBuiltInTemplate
from services.app_dsl_service import AppDslService
from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
from services.recommend_app.recommend_app_type import RecommendAppType
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
#from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class DatabasePipelineTemplateRetrieval(RecommendAppRetrievalBase):
class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
Retrieval pipeline template from database
"""
@@ -21,7 +21,7 @@ class DatabasePipelineTemplateRetrieval(RecommendAppRetrievalBase):
return result
def get_type(self) -> str:
return RecommendAppType.DATABASE
return PipelineTemplateType.DATABASE
@classmethod
def fetch_pipeline_templates_from_db(cls, language: str) -> dict:
@@ -61,5 +61,5 @@ class DatabasePipelineTemplateRetrieval(RecommendAppRetrievalBase):
"name": pipeline.name,
"icon": pipeline.icon,
"mode": pipeline.mode,
"export_data": AppDslService.export_dsl(app_model=pipeline),
"export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline),
}

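The class now plugs into the pipeline-template retrieval hierarchy rather than the recommend-app one. Note that the RagPipelineDslService import is still commented out in this hunk even though export_rag_pipeline_dsl is called in the export path, so the sketch below assumes that import is resolved (for example via a deferred import to break a cycle). Dispatch by type would then look roughly like:

    # Rough dispatch sketch; the no-arg constructor and the registry wiring
    # are assumptions, not code from this commit.
    retrieval = DatabasePipelineTemplateRetrieval()
    assert retrieval.get_type() == PipelineTemplateType.DATABASE
    templates = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(
        language="en-US"
    )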
View File

@@ -13,8 +13,7 @@ from sqlalchemy.orm import Session
import contexts
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repository.repository_factory import RepositoryFactory
from core.repository.workflow_node_execution_repository import OrderConfig
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
@@ -24,6 +23,7 @@ from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
@@ -650,13 +650,9 @@ class RagPipelineService:
if not workflow_run:
return []
# Use the repository to get the node executions
repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": pipeline.tenant_id,
"app_id": pipeline.id,
"session_factory": db.session.get_bind(),
}
# Use the repository to get the node execution
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine, tenant_id=pipeline.tenant_id, app_id=pipeline.id
)
# Use the repository to get the node executions with ordering

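The factory-with-params construction is replaced by instantiating SQLAlchemyWorkflowNodeExecutionRepository directly against db.engine. Fetching a run's node executions in order would then look like the sketch below; the OrderConfig fields and the get_by_workflow_run method name are assumptions based on the imports shown, not visible in this hunk:

    # Sketch: OrderConfig fields and the repository method name are assumed.
    repository = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=db.engine, tenant_id=pipeline.tenant_id, app_id=pipeline.id
    )
    order_config = OrderConfig(order_by=["index"], order_direction="desc")
    node_executions = repository.get_by_workflow_run(
        workflow_run_id=workflow_run.id, order_config=order_config
    )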
View File

@@ -25,12 +25,12 @@ from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.tool.entities import ToolNodeData
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import variable_factory
from models import Account
from models.dataset import Dataset, Pipeline
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
from models.workflow import Workflow
from services.dataset_service import DatasetCollectionBindingService
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.rag_pipeline.rag_pipeline import RagPipelineService
@@ -306,6 +306,26 @@ class RagPipelineDslService:
knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore
knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore
)
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(
DatasetCollectionBinding.provider_name == knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
DatasetCollectionBinding.model_name == knowledge_configuration.index_method.embedding_setting.embedding_model_name,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
.first()
)
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
)
db.session.add(dataset_collection_binding)
db.session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (