refactor: Apply DI to WorkflowNodeExecutionRepository. (#18794)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-04-25 19:05:36 +09:00 committed by GitHub
parent 0e68e8b40a
commit c104febf63
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 94 additions and 18 deletions

View File

@ -7,6 +7,7 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker
import contexts import contexts
from configs import dify_config from configs import dify_config
@ -24,6 +25,8 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repository import RepositoryFactory
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from models.account import Account from models.account import Account
@ -158,11 +161,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
)
return self._generate( return self._generate(
workflow=workflow, workflow=workflow,
user=user, user=user,
invoke_from=invoke_from, invoke_from=invoke_from,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation, conversation=conversation,
stream=streaming, stream=streaming,
) )
@ -215,11 +229,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
)
return self._generate( return self._generate(
workflow=workflow, workflow=workflow,
user=user, user=user,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None, conversation=None,
stream=streaming, stream=streaming,
) )
@ -270,11 +295,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
)
return self._generate( return self._generate(
workflow=workflow, workflow=workflow,
user=user, user=user,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None, conversation=None,
stream=streaming, stream=streaming,
) )
@ -286,6 +322,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser], user: Union[Account, EndUser],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Optional[Conversation] = None, conversation: Optional[Conversation] = None,
stream: bool = True, stream: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
@ -296,6 +333,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param user: account or end user :param user: account or end user
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
:param workflow_node_execution_repository: repository for workflow node execution
:param conversation: conversation :param conversation: conversation
:param stream: is stream :param stream: is stream
""" """
@ -348,6 +386,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation, conversation=conversation,
message=message, message=message,
user=user, user=user,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream, stream=stream,
) )
@ -419,6 +458,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
user: Union[Account, EndUser], user: Union[Account, EndUser],
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False, stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
""" """
@ -430,6 +470,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param message: message :param message: message
:param user: account or end user :param user: account or end user
:param stream: is stream :param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return: :return:
""" """
# init generate task pipeline # init generate task pipeline
@ -442,6 +483,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user, user=user,
stream=stream, stream=stream,
dialogue_count=self._dialogue_count, dialogue_count=self._dialogue_count,
workflow_node_execution_repository=workflow_node_execution_repository,
) )
try: try:

View File

@ -62,6 +62,7 @@ from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
@ -93,6 +94,7 @@ class AdvancedChatAppGenerateTaskPipeline:
user: Union[Account, EndUser], user: Union[Account, EndUser],
stream: bool, stream: bool,
dialogue_count: int, dialogue_count: int,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None: ) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline( self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
@ -123,6 +125,7 @@ class AdvancedChatAppGenerateTaskPipeline:
SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}, },
workflow_node_execution_repository=workflow_node_execution_repository,
) )
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()

View File

@ -7,6 +7,7 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker
import contexts import contexts
from configs import dify_config from configs import dify_config
@ -22,6 +23,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.repository import RepositoryFactory
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from models import Account, App, EndUser, Workflow from models import Account, App, EndUser, Workflow
@ -133,12 +136,23 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
)
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,
workflow=workflow, workflow=workflow,
user=user, user=user,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
invoke_from=invoke_from, invoke_from=invoke_from,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id, workflow_thread_pool_id=workflow_thread_pool_id,
) )
@ -151,6 +165,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Union[Account, EndUser], user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True, streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
@ -162,6 +177,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param user: account or end user :param user: account or end user
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream :param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id :param workflow_thread_pool_id: workflow thread pool id
""" """
@ -193,6 +209,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow, workflow=workflow,
queue_manager=queue_manager, queue_manager=queue_manager,
user=user, user=user,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming, stream=streaming,
) )
@ -245,12 +262,23 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
)
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,
workflow=workflow, workflow=workflow,
user=user, user=user,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
) )
@ -299,12 +327,23 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
)
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,
workflow=workflow, workflow=workflow,
user=user, user=user,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
) )
@ -361,6 +400,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow, workflow: Workflow,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
user: Union[Account, EndUser], user: Union[Account, EndUser],
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False, stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
""" """
@ -370,6 +410,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param queue_manager: queue manager :param queue_manager: queue manager
:param user: account or end user :param user: account or end user
:param stream: is stream :param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return: :return:
""" """
# init generate task pipeline # init generate task pipeline
@ -379,6 +420,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
queue_manager=queue_manager, queue_manager=queue_manager,
user=user, user=user,
stream=stream, stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
) )
try: try:

View File

@ -54,6 +54,7 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
@ -82,6 +83,7 @@ class WorkflowAppGenerateTaskPipeline:
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
user: Union[Account, EndUser], user: Union[Account, EndUser],
stream: bool, stream: bool,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None: ) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline( self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
@ -109,6 +111,7 @@ class WorkflowAppGenerateTaskPipeline:
SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}, },
workflow_node_execution_repository=workflow_node_execution_repository,
) )
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity

View File

@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast
from uuid import uuid4 from uuid import uuid4
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
@ -49,14 +49,13 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.repository import RepositoryFactory from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser from models.model import EndUser
@ -76,26 +75,13 @@ class WorkflowCycleManage:
*, *,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any], workflow_system_variables: dict[SystemVariableKey, Any],
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None: ) -> None:
self._workflow_run: WorkflowRun | None = None self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {} self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables self._workflow_system_variables = workflow_system_variables
self._workflow_node_execution_repository = workflow_node_execution_repository
# Initialize the session factory and repository
# We use the global db engine instead of the session passed to methods
# Disable expire_on_commit to avoid the need for merging objects
self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": self._application_generate_entity.app_config.tenant_id,
"app_id": self._application_generate_entity.app_config.app_id,
"session_factory": self._session_factory,
}
)
# We'll still keep the cache for backward compatibility and performance
# but use the repository for database operations
def _handle_workflow_run_start( def _handle_workflow_run_start(
self, self,