diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 6bf684f8e4..fd63c7787f 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -56,6 +56,7 @@ from models.account import Account from models.model import Conversation, EndUser, Message from models.workflow import ( Workflow, + WorkflowNodeExecution, WorkflowRunStatus, ) @@ -72,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _workflow: Workflow _user: Union[Account, EndUser] _workflow_system_variables: dict[SystemVariableKey, Any] + _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] def __init__( self, @@ -115,6 +117,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc } self._task_state = WorkflowTaskState() + self._wip_workflow_node_executions = {} self._conversation_name_generate_thread = None diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 3afc505367..7c53556e43 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -52,6 +52,7 @@ from models.workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, + WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus, ) @@ -69,6 +70,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariableKey, Any] + _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] def __init__( self, @@ -103,6 +105,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa } self._task_state = WorkflowTaskState() + self._wip_workflow_node_executions = {} def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 4fc587db77..f48ae9c01e 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -57,6 +57,7 @@ class WorkflowCycleManage: _user: Union[Account, EndUser] _task_state: WorkflowTaskState _workflow_system_variables: dict[SystemVariableKey, Any] + _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] def _handle_workflow_run_start(self) -> WorkflowRun: max_sequence = ( @@ -251,6 +252,8 @@ class WorkflowCycleManage: db.session.refresh(workflow_node_execution) db.session.close() + self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution + return workflow_node_execution def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: @@ -275,9 +278,10 @@ class WorkflowCycleManage: workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() db.session.commit() - db.session.refresh(workflow_node_execution) db.session.close() + self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) + return workflow_node_execution def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution: @@ -300,9 +304,10 @@ class WorkflowCycleManage: workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() db.session.commit() - db.session.refresh(workflow_node_execution) db.session.close() + self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) + return workflow_node_execution ################################################# @@ -678,17 +683,7 @@ class WorkflowCycleManage: :param node_execution_id: workflow node execution id :return: """ - workflow_node_execution = ( - db.session.query(WorkflowNodeExecution) - .filter( - WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id, - WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id, - WorkflowNodeExecution.workflow_id == self._workflow.id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.node_execution_id == node_execution_id, - ) - .first() - ) + workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id) if not workflow_node_execution: raise Exception(f"Workflow node execution not found: {node_execution_id}")