From c104febf6372d00b739ab266f2930c14d9cd9a3a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 25 Apr 2025 19:05:36 +0900 Subject: [PATCH] refactor: Apply DI to WorkflowNodeExecutionRepository. (#18794) Signed-off-by: -LAN- --- .../app/apps/advanced_chat/app_generator.py | 42 +++++++++++++++++++ .../advanced_chat/generate_task_pipeline.py | 3 ++ api/core/app/apps/workflow/app_generator.py | 42 +++++++++++++++++++ .../apps/workflow/generate_task_pipeline.py | 3 ++ .../task_pipeline/workflow_cycle_manage.py | 22 ++-------- 5 files changed, 94 insertions(+), 18 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index ef582d28e0..6079b51daa 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -7,6 +7,7 @@ from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError +from sqlalchemy.orm import sessionmaker import contexts from configs import dify_config @@ -24,6 +25,8 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length +from core.repository import RepositoryFactory +from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory from models.account import Account @@ -158,11 +161,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": application_generate_entity.app_config.tenant_id, + "app_id": application_generate_entity.app_config.app_id, + "session_factory": session_factory, + } + ) + return self._generate( workflow=workflow, user=user, invoke_from=invoke_from, application_generate_entity=application_generate_entity, + workflow_node_execution_repository=workflow_node_execution_repository, conversation=conversation, stream=streaming, ) @@ -215,11 +229,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": application_generate_entity.app_config.tenant_id, + "app_id": application_generate_entity.app_config.app_id, + "session_factory": session_factory, + } + ) + return self._generate( workflow=workflow, user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, ) @@ -270,11 +295,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": application_generate_entity.app_config.tenant_id, + "app_id": application_generate_entity.app_config.app_id, + "session_factory": session_factory, + } + ) + return self._generate( workflow=workflow, user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, ) @@ -286,6 +322,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user: Union[Account, EndUser], invoke_from: InvokeFrom, application_generate_entity: AdvancedChatAppGenerateEntity, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, conversation: Optional[Conversation] = None, stream: bool = True, ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: @@ -296,6 +333,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param user: account or end user :param invoke_from: invoke from source :param application_generate_entity: application generate entity + :param workflow_node_execution_repository: repository for workflow node execution :param conversation: conversation :param stream: is stream """ @@ -348,6 +386,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, user=user, + workflow_node_execution_repository=workflow_node_execution_repository, stream=stream, ) @@ -419,6 +458,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation: Conversation, message: Message, user: Union[Account, EndUser], + workflow_node_execution_repository: WorkflowNodeExecutionRepository, stream: bool = False, ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ @@ -430,6 +470,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param message: message :param user: account or end user :param stream: is stream + :param workflow_node_execution_repository: optional repository for workflow node execution :return: """ # init generate task pipeline @@ -442,6 +483,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user=user, stream=stream, dialogue_count=self._dialogue_count, + workflow_node_execution_repository=workflow_node_execution_repository, ) try: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index baefca0c3f..43ccaea9c0 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -62,6 +62,7 @@ from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager +from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes import NodeType @@ -93,6 +94,7 @@ class AdvancedChatAppGenerateTaskPipeline: user: Union[Account, EndUser], stream: bool, dialogue_count: int, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: self._base_task_pipeline = BasedGenerateTaskPipeline( application_generate_entity=application_generate_entity, @@ -123,6 +125,7 @@ class AdvancedChatAppGenerateTaskPipeline: SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, }, + workflow_node_execution_repository=workflow_node_execution_repository, ) self._task_state = WorkflowTaskState() diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 08986b16f0..6be3a7331d 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -7,6 +7,7 @@ from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError +from sqlalchemy.orm import sessionmaker import contexts from configs import dify_config @@ -22,6 +23,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager +from core.repository import RepositoryFactory +from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Workflow @@ -133,12 +136,23 @@ class WorkflowAppGenerator(BaseAppGenerator): contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": application_generate_entity.app_config.tenant_id, + "app_id": application_generate_entity.app_config.app_id, + "session_factory": session_factory, + } + ) + return self._generate( app_model=app_model, workflow=workflow, user=user, application_generate_entity=application_generate_entity, invoke_from=invoke_from, + workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, workflow_thread_pool_id=workflow_thread_pool_id, ) @@ -151,6 +165,7 @@ class WorkflowAppGenerator(BaseAppGenerator): user: Union[Account, EndUser], application_generate_entity: WorkflowAppGenerateEntity, invoke_from: InvokeFrom, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, workflow_thread_pool_id: Optional[str] = None, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: @@ -162,6 +177,7 @@ class WorkflowAppGenerator(BaseAppGenerator): :param user: account or end user :param application_generate_entity: application generate entity :param invoke_from: invoke from source + :param workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream :param workflow_thread_pool_id: workflow thread pool id """ @@ -193,6 +209,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow=workflow, queue_manager=queue_manager, user=user, + workflow_node_execution_repository=workflow_node_execution_repository, stream=streaming, ) @@ -245,12 +262,23 @@ class WorkflowAppGenerator(BaseAppGenerator): contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": application_generate_entity.app_config.tenant_id, + "app_id": application_generate_entity.app_config.app_id, + "session_factory": session_factory, + } + ) + return self._generate( app_model=app_model, workflow=workflow, user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, ) @@ -299,12 +327,23 @@ class WorkflowAppGenerator(BaseAppGenerator): contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": application_generate_entity.app_config.tenant_id, + "app_id": application_generate_entity.app_config.app_id, + "session_factory": session_factory, + } + ) + return self._generate( app_model=app_model, workflow=workflow, user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, ) @@ -361,6 +400,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow: Workflow, queue_manager: AppQueueManager, user: Union[Account, EndUser], + workflow_node_execution_repository: WorkflowNodeExecutionRepository, stream: bool = False, ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ @@ -370,6 +410,7 @@ class WorkflowAppGenerator(BaseAppGenerator): :param queue_manager: queue manager :param user: account or end user :param stream: is stream + :param workflow_node_execution_repository: optional repository for workflow node execution :return: """ # init generate task pipeline @@ -379,6 +420,7 @@ class WorkflowAppGenerator(BaseAppGenerator): queue_manager=queue_manager, user=user, stream=stream, + workflow_node_execution_repository=workflow_node_execution_repository, ) try: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 1f998edb6a..68131a7463 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -54,6 +54,7 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager +from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account @@ -82,6 +83,7 @@ class WorkflowAppGenerateTaskPipeline: queue_manager: AppQueueManager, user: Union[Account, EndUser], stream: bool, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: self._base_task_pipeline = BasedGenerateTaskPipeline( application_generate_entity=application_generate_entity, @@ -109,6 +111,7 @@ class WorkflowAppGenerateTaskPipeline: SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, }, + workflow_node_execution_repository=workflow_node_execution_repository, ) self._application_generate_entity = application_generate_entity diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 5ce9f737d1..38e7c9eb12 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast from uuid import uuid4 from sqlalchemy import func, select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( @@ -49,14 +49,13 @@ from core.file import FILE_MODEL_IDENTITY, File from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from core.repository import RepositoryFactory +from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.tools.tool_manager import ToolManager from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_entry import WorkflowEntry -from extensions.ext_database import db from models.account import Account from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.model import EndUser @@ -76,26 +75,13 @@ class WorkflowCycleManage: *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], workflow_system_variables: dict[SystemVariableKey, Any], + workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: self._workflow_run: WorkflowRun | None = None self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {} self._application_generate_entity = application_generate_entity self._workflow_system_variables = workflow_system_variables - - # Initialize the session factory and repository - # We use the global db engine instead of the session passed to methods - # Disable expire_on_commit to avoid the need for merging objects - self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( - params={ - "tenant_id": self._application_generate_entity.app_config.tenant_id, - "app_id": self._application_generate_entity.app_config.app_id, - "session_factory": self._session_factory, - } - ) - - # We'll still keep the cache for backward compatibility and performance - # but use the repository for database operations + self._workflow_node_execution_repository = workflow_node_execution_repository def _handle_workflow_run_start( self,