diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index b74100bb19..4b021aa0b3 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -26,10 +26,13 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom +from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService from services.errors.message import MessageNotExistsError @@ -159,8 +162,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - # Create workflow node execution repository + # Create repositories + # + # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + if invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, user=user, @@ -173,6 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user=user, invoke_from=invoke_from, application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, conversation=conversation, stream=streaming, @@ -226,8 +244,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - # Create workflow node execution repository + # Create repositories + # + # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + # Create workflow node execution repository workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, user=user, @@ -240,6 +268,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, @@ -291,8 +320,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - # Create workflow node execution repository + # Create repositories + # + # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + # Create workflow node execution repository workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, user=user, @@ -305,6 +344,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, @@ -317,6 +357,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user: Union[Account, EndUser], invoke_from: InvokeFrom, application_generate_entity: AdvancedChatAppGenerateEntity, + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, conversation: Optional[Conversation] = None, stream: bool = True, @@ -381,6 +422,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, user=user, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, stream=stream, ) @@ -453,6 +495,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation: Conversation, message: Message, user: Union[Account, EndUser], + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, stream: bool = False, ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: @@ -476,9 +519,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, user=user, - stream=stream, dialogue_count=self._dialogue_count, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, + stream=stream, ) try: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 735b2a9709..cd764d56a5 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -64,6 +64,7 @@ from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes import NodeType +from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.workflow_cycle_manager import WorkflowCycleManager from events.message_event import message_was_created @@ -94,6 +95,7 @@ class AdvancedChatAppGenerateTaskPipeline: user: Union[Account, EndUser], stream: bool, dialogue_count: int, + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: self._base_task_pipeline = BasedGenerateTaskPipeline( @@ -125,6 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline: SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, }, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, ) @@ -294,21 +297,19 @@ class AdvancedChatAppGenerateTaskPipeline: with Session(db.engine, expire_on_commit=False) as session: # init workflow run - workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start( session=session, workflow_id=self._workflow_id, - user_id=self._user_id, - created_by_role=self._created_by_role, ) - self._workflow_run_id = workflow_run.id + self._workflow_run_id = workflow_execution.id message = self._get_message(session=session) if not message: raise ValueError(f"Message not found: {self._message_id}") - message.workflow_run_id = workflow_run.id - workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + message.workflow_run_id = workflow_execution.id + workflow_start_resp = self._workflow_cycle_manager.workflow_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) - session.commit() yield workflow_start_resp elif isinstance( @@ -319,13 +320,10 @@ class AdvancedChatAppGenerateTaskPipeline: raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( + workflow_execution_id=self._workflow_run_id, event=event ) - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event - ) - node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( + node_retry_resp = self._workflow_cycle_manager.workflow_node_retry_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -338,20 +336,15 @@ class AdvancedChatAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( - workflow_run=workflow_run, event=event - ) + workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( + workflow_execution_id=self._workflow_run_id, event=event + ) - node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + node_start_resp = self._workflow_cycle_manager.workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) if node_start_resp: yield node_start_resp @@ -359,15 +352,15 @@ class AdvancedChatAppGenerateTaskPipeline: # Record files if it's an answer node or end node if event.node_type in [NodeType.ANSWER, NodeType.END]: self._recorded_files.extend( - self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {}) + self._workflow_cycle_manager.fetch_files_from_node_outputs(event.outputs or {}) ) with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success( event=event ) - node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + node_finish_resp = self._workflow_cycle_manager.workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -383,11 +376,11 @@ class AdvancedChatAppGenerateTaskPipeline: | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, ): - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( event=event ) - node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + node_finish_resp = self._workflow_cycle_manager.workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -399,132 +392,90 @@ class AdvancedChatAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - parallel_start_resp = ( - self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) - ) + parallel_start_resp = self._workflow_cycle_manager.workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield parallel_start_resp elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - parallel_finish_resp = ( - self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + parallel_finish_resp = ( + self._workflow_cycle_manager.workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, ) + ) yield parallel_finish_resp elif isinstance(event, QueueIterationStartEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + iter_start_resp = self._workflow_cycle_manager.workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield iter_start_resp elif isinstance(event, QueueIterationNextEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + iter_next_resp = self._workflow_cycle_manager.workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield iter_next_resp elif isinstance(event, QueueIterationCompletedEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + iter_finish_resp = self._workflow_cycle_manager.workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield iter_finish_resp elif isinstance(event, QueueLoopStartEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + loop_start_resp = self._workflow_cycle_manager.workflow_loop_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield loop_start_resp elif isinstance(event, QueueLoopNextEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + loop_next_resp = self._workflow_cycle_manager.workflow_loop_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield loop_next_resp elif isinstance(event, QueueLoopCompletedEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + loop_finish_resp = self._workflow_cycle_manager.workflow_loop_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield loop_finish_resp elif isinstance(event, QueueWorkflowSucceededEvent): @@ -535,10 +486,8 @@ class AdvancedChatAppGenerateTaskPipeline: raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_success( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, outputs=event.outputs, @@ -546,10 +495,11 @@ class AdvancedChatAppGenerateTaskPipeline: trace_manager=trace_manager, ) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) - session.commit() yield workflow_finish_resp self._base_task_pipeline._queue_manager.publish( @@ -562,10 +512,8 @@ class AdvancedChatAppGenerateTaskPipeline: raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, outputs=event.outputs, @@ -573,10 +521,11 @@ class AdvancedChatAppGenerateTaskPipeline: conversation_id=None, trace_manager=trace_manager, ) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) - session.commit() yield workflow_finish_resp self._base_task_pipeline._queue_manager.publish( @@ -589,26 +538,25 @@ class AdvancedChatAppGenerateTaskPipeline: raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, status=WorkflowRunStatus.FAILED, - error=event.error, + error_message=event.error, conversation_id=self._conversation_id, trace_manager=trace_manager, exceptions_count=event.exceptions_count, ) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) - err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) err = self._base_task_pipeline._handle_error( event=err_event, session=session, message_id=self._message_id ) - session.commit() yield workflow_finish_resp yield self._base_task_pipeline._error_to_stream_response(err) @@ -616,21 +564,19 @@ class AdvancedChatAppGenerateTaskPipeline: elif isinstance(event, QueueStopEvent): if self._workflow_run_id and graph_runtime_state: with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, status=WorkflowRunStatus.STOPPED, - error=event.get_stop_reason(), + error_message=event.get_stop_reason(), conversation_id=self._conversation_id, trace_manager=trace_manager, ) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( + workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, + workflow_execution=workflow_execution, ) # Save message self._save_message(session=session, graph_runtime_state=graph_runtime_state) @@ -711,7 +657,7 @@ class AdvancedChatAppGenerateTaskPipeline: yield self._message_end_to_stream_response() elif isinstance(event, QueueAgentLogEvent): - yield self._workflow_cycle_manager._handle_agent_log( + yield self._workflow_cycle_manager.handle_agent_log( task_id=self._application_generate_entity.task_id, event=event ) else: diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index d49ff682b9..01d35c57ce 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -18,16 +18,19 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter +from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom +from models.enums import WorkflowRunTriggeredFrom logger = logging.getLogger(__name__) @@ -136,9 +139,22 @@ class WorkflowAppGenerator(BaseAppGenerator): contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - # Create workflow node execution repository + # Create repositories + # + # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - + # Create workflow execution(aka workflow run) repository + if invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, user=user, @@ -152,6 +168,7 @@ class WorkflowAppGenerator(BaseAppGenerator): user=user, application_generate_entity=application_generate_entity, invoke_from=invoke_from, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, workflow_thread_pool_id=workflow_thread_pool_id, @@ -165,6 +182,7 @@ class WorkflowAppGenerator(BaseAppGenerator): user: Union[Account, EndUser], application_generate_entity: WorkflowAppGenerateEntity, invoke_from: InvokeFrom, + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, workflow_thread_pool_id: Optional[str] = None, @@ -209,6 +227,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow=workflow, queue_manager=queue_manager, user=user, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, stream=streaming, ) @@ -262,6 +281,17 @@ class WorkflowAppGenerator(BaseAppGenerator): contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) + # Create repositories + # + # Create session factory + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) @@ -278,6 +308,7 @@ class WorkflowAppGenerator(BaseAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, ) @@ -327,6 +358,17 @@ class WorkflowAppGenerator(BaseAppGenerator): contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) + # Create repositories + # + # Create session factory + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) @@ -343,6 +385,7 @@ class WorkflowAppGenerator(BaseAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, ) @@ -400,6 +443,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow: Workflow, queue_manager: AppQueueManager, user: Union[Account, EndUser], + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, stream: bool = False, ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -419,8 +463,9 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow=workflow, queue_manager=queue_manager, user=user, - stream=stream, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, + stream=stream, ) try: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py new file mode 100644 index 0000000000..f2ebd78b36 --- /dev/null +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -0,0 +1,591 @@ +import logging +import time +from collections.abc import Generator +from typing import Optional, Union + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import ( + InvokeFrom, + WorkflowAppGenerateEntity, +) +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueErrorEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueLoopCompletedEvent, + QueueLoopNextEvent, + QueueLoopStartEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, + QueueNodeInLoopFailedEvent, + QueueNodeRetryEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, + QueuePingEvent, + QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.app.entities.task_entities import ( + ErrorStreamResponse, + MessageAudioEndStreamResponse, + MessageAudioStreamResponse, + StreamResponse, + TextChunkStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, + WorkflowFinishStreamResponse, + WorkflowStartStreamResponse, + WorkflowTaskState, +) +from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline +from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk +from core.ops.ops_trace_manager import TraceQueueManager +from core.workflow.entities.workflow_execution_entities import WorkflowExecution +from core.workflow.enums import SystemVariableKey +from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.workflow_cycle_manager import WorkflowCycleManager +from extensions.ext_database import db +from models.account import Account +from models.enums import CreatorUserRole +from models.model import EndUser +from models.workflow import ( + Workflow, + WorkflowAppLog, + WorkflowAppLogCreatedFrom, + WorkflowRun, + WorkflowRunStatus, +) + +logger = logging.getLogger(__name__) + + +class WorkflowAppGenerateTaskPipeline: + """ + WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. + """ + + def __init__( + self, + application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool, + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, + ) -> None: + self._base_task_pipeline = BasedGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) + + if isinstance(user, EndUser): + self._user_id = user.id + user_session_id = user.session_id + self._created_by_role = CreatorUserRole.END_USER + elif isinstance(user, Account): + self._user_id = user.id + user_session_id = user.id + self._created_by_role = CreatorUserRole.ACCOUNT + else: + raise ValueError(f"Invalid user type: {type(user)}") + + self._workflow_cycle_manager = WorkflowCycleManager( + application_generate_entity=application_generate_entity, + workflow_system_variables={ + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.USER_ID: user_session_id, + SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, + SystemVariableKey.WORKFLOW_ID: workflow.id, + SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, + }, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + ) + + self._application_generate_entity = application_generate_entity + self._workflow_id = workflow.id + self._workflow_features_dict = workflow.features_dict + self._task_state = WorkflowTaskState() + self._workflow_run_id = "" + + def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: + """ + Process generate task pipeline. + :return: + """ + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) + if self._base_task_pipeline._stream: + return self._to_stream_response(generator) + else: + return self._to_blocking_response(generator) + + def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse: + """ + To blocking response. + :return: + """ + for stream_response in generator: + if isinstance(stream_response, ErrorStreamResponse): + raise stream_response.err + elif isinstance(stream_response, WorkflowFinishStreamResponse): + response = WorkflowAppBlockingResponse( + task_id=self._application_generate_entity.task_id, + workflow_run_id=stream_response.data.id, + data=WorkflowAppBlockingResponse.Data( + id=stream_response.data.id, + workflow_id=stream_response.data.workflow_id, + status=stream_response.data.status, + outputs=stream_response.data.outputs, + error=stream_response.data.error, + elapsed_time=stream_response.data.elapsed_time, + total_tokens=stream_response.data.total_tokens, + total_steps=stream_response.data.total_steps, + created_at=int(stream_response.data.created_at), + finished_at=int(stream_response.data.finished_at), + ), + ) + + return response + else: + continue + + raise ValueError("queue listening stopped unexpectedly.") + + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[WorkflowAppStreamResponse, None, None]: + """ + To stream response. + :return: + """ + workflow_run_id = None + for stream_response in generator: + if isinstance(stream_response, WorkflowStartStreamResponse): + workflow_run_id = stream_response.workflow_run_id + + yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) + + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): + if not publisher: + return None + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": + return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) + return None + + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: + tts_publisher = None + task_id = self._application_generate_entity.task_id + tenant_id = self._application_generate_entity.app_config.tenant_id + features_dict = self._workflow_features_dict + + if ( + features_dict.get("text_to_speech") + and features_dict["text_to_speech"].get("enabled") + and features_dict["text_to_speech"].get("autoPlay") == "enabled" + ): + tts_publisher = AppGeneratorTTSPublisher( + tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language") + ) + + for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): + while True: + audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) + if audio_response: + yield audio_response + else: + break + yield response + + start_listener_time = time.time() + while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: + try: + if not tts_publisher: + break + audio_trunk = tts_publisher.check_and_get_audio() + if audio_trunk is None: + # release cpu + # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) + time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME) + continue + if audio_trunk.status == "finish": + break + else: + yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) + except Exception: + logger.exception(f"Fails to get audio trunk, task_id: {task_id}") + break + if tts_publisher: + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) + + def _process_stream_response( + self, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, + ) -> Generator[StreamResponse, None, None]: + """ + Process stream response. + :return: + """ + graph_runtime_state = None + + for queue_message in self._base_task_pipeline._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueuePingEvent): + yield self._base_task_pipeline._ping_stream_response() + elif isinstance(event, QueueErrorEvent): + err = self._base_task_pipeline._handle_error(event=event) + yield self._base_task_pipeline._error_to_stream_response(err) + break + elif isinstance(event, QueueWorkflowStartedEvent): + # override graph runtime state + graph_runtime_state = event.graph_runtime_state + + with Session(db.engine, expire_on_commit=False) as session: + # init workflow run + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start( + session=session, + workflow_id=self._workflow_id, + ) + self._workflow_run_id = workflow_execution.id + start_resp = self._workflow_cycle_manager.workflow_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + + yield start_resp + elif isinstance( + event, + QueueNodeRetryEvent, + ): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + with Session(db.engine, expire_on_commit=False) as session: + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( + workflow_execution_id=self._workflow_run_id, + event=event, + ) + response = self._workflow_cycle_manager.workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() + + if response: + yield response + elif isinstance(event, QueueNodeStartedEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( + workflow_execution_id=self._workflow_run_id, event=event + ) + node_start_response = self._workflow_cycle_manager.workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if node_start_response: + yield node_start_response + elif isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success( + event=event + ) + node_success_response = self._workflow_cycle_manager.workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if node_success_response: + yield node_success_response + elif isinstance( + event, + QueueNodeFailedEvent + | QueueNodeInIterationFailedEvent + | QueueNodeInLoopFailedEvent + | QueueNodeExceptionEvent, + ): + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( + event=event, + ) + node_failed_response = self._workflow_cycle_manager.workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if node_failed_response: + yield node_failed_response + + elif isinstance(event, QueueParallelBranchRunStartedEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + parallel_start_resp = self._workflow_cycle_manager.workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + + yield parallel_start_resp + + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + parallel_finish_resp = ( + self._workflow_cycle_manager.workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + ) + + yield parallel_finish_resp + + elif isinstance(event, QueueIterationStartEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + iter_start_resp = self._workflow_cycle_manager.workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + + yield iter_start_resp + + elif isinstance(event, QueueIterationNextEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + iter_next_resp = self._workflow_cycle_manager.workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + + yield iter_next_resp + + elif isinstance(event, QueueIterationCompletedEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + iter_finish_resp = self._workflow_cycle_manager.workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + + yield iter_finish_resp + + elif isinstance(event, QueueLoopStartEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + loop_start_resp = self._workflow_cycle_manager.workflow_loop_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + + yield loop_start_resp + + elif isinstance(event, QueueLoopNextEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + loop_next_resp = self._workflow_cycle_manager.workflow_loop_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + + yield loop_next_resp + + elif isinstance(event, QueueLoopCompletedEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + loop_finish_resp = self._workflow_cycle_manager.workflow_loop_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + + yield loop_finish_resp + + elif isinstance(event, QueueWorkflowSucceededEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + if not graph_runtime_state: + raise ValueError("graph runtime state not initialized.") + + with Session(db.engine, expire_on_commit=False) as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( + workflow_run_id=self._workflow_run_id, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + conversation_id=None, + trace_manager=trace_manager, + ) + + # save workflow app log + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) + + workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + session.commit() + + yield workflow_finish_resp + elif isinstance(event, QueueWorkflowPartialSuccessEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + if not graph_runtime_state: + raise ValueError("graph runtime state not initialized.") + + with Session(db.engine, expire_on_commit=False) as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( + workflow_run_id=self._workflow_run_id, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) + + # save workflow app log + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) + + workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + session.commit() + + yield workflow_finish_resp + elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + if not graph_runtime_state: + raise ValueError("graph runtime state not initialized.") + + with Session(db.engine, expire_on_commit=False) as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( + workflow_run_id=self._workflow_run_id, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, + error_message=event.error + if isinstance(event, QueueWorkflowFailedEvent) + else event.get_stop_reason(), + conversation_id=None, + trace_manager=trace_manager, + exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, + ) + + # save workflow app log + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) + + workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + session.commit() + + yield workflow_finish_resp + elif isinstance(event, QueueTextChunkEvent): + delta_text = event.text + if delta_text is None: + continue + + # only publish tts message at text chunk streaming + if tts_publisher: + tts_publisher.publish(queue_message) + + self._task_state.answer += delta_text + yield self._text_chunk_to_stream_response( + delta_text, from_variable_selector=event.from_variable_selector + ) + elif isinstance(event, QueueAgentLogEvent): + yield self._workflow_cycle_manager.handle_agent_log( + task_id=self._application_generate_entity.task_id, event=event + ) + else: + continue + + if tts_publisher: + tts_publisher.publish(None) + + def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: + workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id)) + assert workflow_run is not None + invoke_from = self._application_generate_entity.invoke_from + if invoke_from == InvokeFrom.SERVICE_API: + created_from = WorkflowAppLogCreatedFrom.SERVICE_API + elif invoke_from == InvokeFrom.EXPLORE: + created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP + elif invoke_from == InvokeFrom.WEB_APP: + created_from = WorkflowAppLogCreatedFrom.WEB_APP + else: + # not save log for debugging + return + + workflow_app_log = WorkflowAppLog() + workflow_app_log.tenant_id = workflow_run.tenant_id + workflow_app_log.app_id = workflow_run.app_id + workflow_app_log.workflow_id = workflow_run.workflow_id + workflow_app_log.workflow_run_id = workflow_run.id + workflow_app_log.created_from = created_from.value + workflow_app_log.created_by_role = self._created_by_role + workflow_app_log.created_by = self._user_id + + session.add(workflow_app_log) + session.commit() + + def _text_chunk_to_stream_response( + self, text: str, from_variable_selector: Optional[list[str]] = None + ) -> TextChunkStreamResponse: + """ + Handle completed event. + :param text: text + :return: + """ + response = TextChunkStreamResponse( + task_id=self._application_generate_entity.task_id, + data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector), + ) + + return response diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 0c2d617f80..9b2bfcbf61 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -190,7 +190,7 @@ class WorkflowStartStreamResponse(StreamResponse): id: str workflow_id: str sequence_number: int - inputs: dict + inputs: Mapping[str, Any] created_at: int event: StreamEvent = StreamEvent.WORKFLOW_STARTED @@ -212,7 +212,7 @@ class WorkflowFinishStreamResponse(StreamResponse): workflow_id: str sequence_number: int status: str - outputs: Optional[dict] = None + outputs: Optional[Mapping[str, Any]] = None error: Optional[str] = None elapsed_time: float total_tokens: int @@ -788,7 +788,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): id: str workflow_id: str status: str - outputs: Optional[dict] = None + outputs: Optional[Mapping[str, Any]] = None error: Optional[str] = None elapsed_time: float total_tokens: int diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 84520a5991..47c65a3e50 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -30,6 +30,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.utils import get_message_data +from core.workflow.entities.workflow_execution_entities import WorkflowExecution from extensions.ext_database import db from extensions.ext_storage import storage from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig @@ -373,7 +374,7 @@ class TraceTask: self, trace_type: Any, message_id: Optional[str] = None, - workflow_run: Optional[WorkflowRun] = None, + workflow_execution: Optional[WorkflowExecution] = None, conversation_id: Optional[str] = None, user_id: Optional[str] = None, timer: Optional[Any] = None, @@ -381,7 +382,7 @@ class TraceTask: ): self.trace_type = trace_type self.message_id = message_id - self.workflow_run_id = workflow_run.id if workflow_run else None + self.workflow_run_id = workflow_execution.id if workflow_execution else None self.conversation_id = conversation_id self.user_id = user_id self.timer = timer diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py new file mode 100644 index 0000000000..c1a71b45d0 --- /dev/null +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -0,0 +1,242 @@ +""" +SQLAlchemy implementation of the WorkflowExecutionRepository. +""" + +import json +import logging +from typing import Optional, Union + +from sqlalchemy import select +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.workflow.entities.workflow_execution_entities import ( + WorkflowExecution, + WorkflowExecutionStatus, + WorkflowType, +) +from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository +from models import ( + Account, + CreatorUserRole, + EndUser, + WorkflowRun, +) +from models.enums import WorkflowRunTriggeredFrom + +logger = logging.getLogger(__name__) + + +class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): + """ + SQLAlchemy implementation of the WorkflowExecutionRepository interface. + + This implementation supports multi-tenancy by filtering operations based on tenant_id. + Each method creates its own session, handles the transaction, and commits changes + to the database. This prevents long-running connections in the workflow core. + + This implementation also includes an in-memory cache for workflow executions to improve + performance by reducing database queries. + """ + + def __init__( + self, + session_factory: sessionmaker | Engine, + user: Union[Account, EndUser], + app_id: Optional[str], + triggered_from: Optional[WorkflowRunTriggeredFrom], + ): + """ + Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. + + Args: + session_factory: SQLAlchemy sessionmaker or engine for creating sessions + user: Account or EndUser object containing tenant_id, user ID, and role information + app_id: App ID for filtering by application (can be None) + triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN) + """ + # If an engine is provided, create a sessionmaker from it + if isinstance(session_factory, Engine): + self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) + elif isinstance(session_factory, sessionmaker): + self._session_factory = session_factory + else: + raise ValueError( + f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" + ) + + # Extract tenant_id from user + tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + if not tenant_id: + raise ValueError("User must have a tenant_id or current_tenant_id") + self._tenant_id = tenant_id + + # Store app context + self._app_id = app_id + + # Extract user context + self._triggered_from = triggered_from + self._creator_user_id = user.id + + # Determine user role based on user type + self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER + + # Initialize in-memory cache for workflow executions + # Key: execution_id, Value: WorkflowRun (DB model) + self._execution_cache: dict[str, WorkflowRun] = {} + + def _to_domain_model(self, db_model: WorkflowRun) -> WorkflowExecution: + """ + Convert a database model to a domain model. + + Args: + db_model: The database model to convert + + Returns: + The domain model + """ + # Parse JSON fields + inputs = db_model.inputs_dict + outputs = db_model.outputs_dict + graph = db_model.graph_dict + + # Convert status to domain enum + status = WorkflowExecutionStatus(db_model.status) + + return WorkflowExecution( + id=db_model.id, + workflow_id=db_model.workflow_id, + sequence_number=db_model.sequence_number, + type=WorkflowType(db_model.type), + workflow_version=db_model.version, + graph=graph, + inputs=inputs, + outputs=outputs, + status=status, + error_message=db_model.error or "", + total_tokens=db_model.total_tokens, + total_steps=db_model.total_steps, + exceptions_count=db_model.exceptions_count, + started_at=db_model.created_at, + finished_at=db_model.finished_at, + ) + + def _to_db_model(self, domain_model: WorkflowExecution) -> WorkflowRun: + """ + Convert a domain model to a database model. + + Args: + domain_model: The domain model to convert + + Returns: + The database model + """ + # Use values from constructor if provided + if not self._triggered_from: + raise ValueError("triggered_from is required in repository constructor") + if not self._creator_user_id: + raise ValueError("created_by is required in repository constructor") + if not self._creator_user_role: + raise ValueError("created_by_role is required in repository constructor") + + db_model = WorkflowRun() + db_model.id = domain_model.id + db_model.tenant_id = self._tenant_id + if self._app_id is not None: + db_model.app_id = self._app_id + db_model.workflow_id = domain_model.workflow_id + db_model.triggered_from = self._triggered_from + db_model.sequence_number = domain_model.sequence_number + db_model.type = domain_model.type + db_model.version = domain_model.workflow_version + db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None + db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None + db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None + db_model.status = domain_model.status + db_model.error = domain_model.error_message if domain_model.error_message else None + db_model.total_tokens = domain_model.total_tokens + db_model.total_steps = domain_model.total_steps + db_model.exceptions_count = domain_model.exceptions_count + db_model.created_by_role = self._creator_user_role + db_model.created_by = self._creator_user_id + db_model.created_at = domain_model.started_at + db_model.finished_at = domain_model.finished_at + + # Calculate elapsed time if finished_at is available + if domain_model.finished_at: + db_model.elapsed_time = (domain_model.finished_at - domain_model.started_at).total_seconds() + else: + db_model.elapsed_time = 0 + + return db_model + + def save(self, execution: WorkflowExecution) -> None: + """ + Save or update a WorkflowExecution domain entity to the database. + + This method serves as a domain-to-database adapter that: + 1. Converts the domain entity to its database representation + 2. Persists the database model using SQLAlchemy's merge operation + 3. Maintains proper multi-tenancy by including tenant context during conversion + 4. Updates the in-memory cache for faster subsequent lookups + + The method handles both creating new records and updating existing ones through + SQLAlchemy's merge operation. + + Args: + execution: The WorkflowExecution domain entity to persist + """ + # Convert domain model to database model using tenant context and other attributes + db_model = self._to_db_model(execution) + + # Create a new database session + with self._session_factory() as session: + # SQLAlchemy merge intelligently handles both insert and update operations + # based on the presence of the primary key + session.merge(db_model) + session.commit() + + # Update the in-memory cache for faster subsequent lookups + logger.debug(f"Updating cache for execution_id: {db_model.id}") + self._execution_cache[db_model.id] = db_model + + def get(self, execution_id: str) -> Optional[WorkflowExecution]: + """ + Retrieve a WorkflowExecution by its ID. + + First checks the in-memory cache, and if not found, queries the database. + If found in the database, adds it to the cache for future lookups. + + Args: + execution_id: The workflow execution ID + + Returns: + The WorkflowExecution instance if found, None otherwise + """ + # First check the cache + if execution_id in self._execution_cache: + logger.debug(f"Cache hit for execution_id: {execution_id}") + # Convert cached DB model to domain model + cached_db_model = self._execution_cache[execution_id] + return self._to_domain_model(cached_db_model) + + # If not in cache, query the database + logger.debug(f"Cache miss for execution_id: {execution_id}, querying database") + with self._session_factory() as session: + stmt = select(WorkflowRun).where( + WorkflowRun.id == execution_id, + WorkflowRun.tenant_id == self._tenant_id, + ) + + if self._app_id: + stmt = stmt.where(WorkflowRun.app_id == self._app_id) + + db_model = session.scalar(stmt) + if db_model: + # Add DB model to cache + self._execution_cache[execution_id] = db_model + + # Convert to domain model and return + return self._to_domain_model(db_model) + + return None diff --git a/api/core/workflow/entities/workflow_execution_entities.py b/api/core/workflow/entities/workflow_execution_entities.py new file mode 100644 index 0000000000..200d4697b5 --- /dev/null +++ b/api/core/workflow/entities/workflow_execution_entities.py @@ -0,0 +1,91 @@ +""" +Domain entities for workflow execution. + +Models are independent of the storage mechanism and don't contain +implementation details like tenant_id, app_id, etc. +""" + +from collections.abc import Mapping +from datetime import UTC, datetime +from enum import StrEnum +from typing import Any, Optional + +from pydantic import BaseModel, Field + + +class WorkflowType(StrEnum): + """ + Workflow Type Enum for domain layer + """ + + WORKFLOW = "workflow" + CHAT = "chat" + + +class WorkflowExecutionStatus(StrEnum): + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + STOPPED = "stopped" + PARTIAL_SUCCEEDED = "partial-succeeded" + + +class WorkflowExecution(BaseModel): + """ + Domain model for workflow execution based on WorkflowRun but without + user, tenant, and app attributes. + """ + + id: str = Field(...) + workflow_id: str = Field(...) + workflow_version: str = Field(...) + sequence_number: int = Field(...) + + type: WorkflowType = Field(...) + graph: Mapping[str, Any] = Field(...) + + inputs: Mapping[str, Any] = Field(...) + outputs: Optional[Mapping[str, Any]] = None + + status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING + error_message: str = Field(default="") + total_tokens: int = Field(default=0) + total_steps: int = Field(default=0) + exceptions_count: int = Field(default=0) + + started_at: datetime = Field(...) + finished_at: Optional[datetime] = None + + @property + def elapsed_time(self) -> float: + """ + Calculate elapsed time in seconds. + If workflow is not finished, use current time. + """ + end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None) + return (end_time - self.started_at).total_seconds() + + @classmethod + def new( + cls, + *, + id: str, + workflow_id: str, + sequence_number: int, + type: WorkflowType, + workflow_version: str, + graph: Mapping[str, Any], + inputs: Mapping[str, Any], + started_at: datetime, + ) -> "WorkflowExecution": + return WorkflowExecution( + id=id, + workflow_id=workflow_id, + sequence_number=sequence_number, + type=type, + workflow_version=workflow_version, + graph=graph, + inputs=inputs, + status=WorkflowExecutionStatus.RUNNING, + started_at=started_at, + ) diff --git a/api/core/workflow/repository/workflow_execution_repository.py b/api/core/workflow/repository/workflow_execution_repository.py new file mode 100644 index 0000000000..a39a98ee33 --- /dev/null +++ b/api/core/workflow/repository/workflow_execution_repository.py @@ -0,0 +1,42 @@ +from typing import Optional, Protocol + +from core.workflow.entities.workflow_execution_entities import WorkflowExecution + + +class WorkflowExecutionRepository(Protocol): + """ + Repository interface for WorkflowExecution. + + This interface defines the contract for accessing and manipulating + WorkflowExecution data, regardless of the underlying storage mechanism. + + Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), + and other implementation details should be handled at the implementation level, not in + the core interface. This keeps the core domain model clean and independent of specific + application domains or deployment scenarios. + """ + + def save(self, execution: WorkflowExecution) -> None: + """ + Save or update a WorkflowExecution instance. + + This method handles both creating new records and updating existing ones. + The implementation should determine whether to create or update based on + the execution's ID or other identifying fields. + + Args: + execution: The WorkflowExecution instance to save or update + """ + ... + + def get(self, execution_id: str) -> Optional[WorkflowExecution]: + """ + Retrieve a WorkflowExecution by its ID. + + Args: + execution_id: The workflow execution ID + + Returns: + The WorkflowExecution instance if found, None otherwise + """ + ... diff --git a/api/core/workflow/workflow_app_generate_task_pipeline.py b/api/core/workflow/workflow_app_generate_task_pipeline.py index 0396fa8157..f2ebd78b36 100644 --- a/api/core/workflow/workflow_app_generate_task_pipeline.py +++ b/api/core/workflow/workflow_app_generate_task_pipeline.py @@ -3,6 +3,7 @@ import time from collections.abc import Generator from typing import Optional, Union +from sqlalchemy import select from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -53,7 +54,9 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager +from core.workflow.entities.workflow_execution_entities import WorkflowExecution from core.workflow.enums import SystemVariableKey +from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.workflow_cycle_manager import WorkflowCycleManager from extensions.ext_database import db @@ -83,6 +86,7 @@ class WorkflowAppGenerateTaskPipeline: queue_manager: AppQueueManager, user: Union[Account, EndUser], stream: bool, + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: self._base_task_pipeline = BasedGenerateTaskPipeline( @@ -111,6 +115,7 @@ class WorkflowAppGenerateTaskPipeline: SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, }, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, ) @@ -258,17 +263,15 @@ class WorkflowAppGenerateTaskPipeline: with Session(db.engine, expire_on_commit=False) as session: # init workflow run - workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start( session=session, workflow_id=self._workflow_id, - user_id=self._user_id, - created_by_role=self._created_by_role, ) - self._workflow_run_id = workflow_run.id - start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + self._workflow_run_id = workflow_execution.id + start_resp = self._workflow_cycle_manager.workflow_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) - session.commit() yield start_resp elif isinstance( @@ -278,13 +281,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( + workflow_execution_id=self._workflow_run_id, + event=event, ) - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event - ) - response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( + response = self._workflow_cycle_manager.workflow_node_retry_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -297,27 +298,22 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( - workflow_run=workflow_run, event=event - ) - node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( + workflow_execution_id=self._workflow_run_id, event=event + ) + node_start_response = self._workflow_cycle_manager.workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) if node_start_response: yield node_start_response elif isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success( event=event ) - node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + node_success_response = self._workflow_cycle_manager.workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -332,10 +328,10 @@ class WorkflowAppGenerateTaskPipeline: | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, ): - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( event=event, ) - node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + node_failed_response = self._workflow_cycle_manager.workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -348,18 +344,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - parallel_start_resp = ( - self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) - ) + parallel_start_resp = self._workflow_cycle_manager.workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield parallel_start_resp @@ -367,18 +356,13 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - parallel_finish_resp = ( - self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + parallel_finish_resp = ( + self._workflow_cycle_manager.workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, ) + ) yield parallel_finish_resp @@ -386,16 +370,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + iter_start_resp = self._workflow_cycle_manager.workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield iter_start_resp @@ -403,16 +382,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + iter_next_resp = self._workflow_cycle_manager.workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield iter_next_resp @@ -420,16 +394,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + iter_finish_resp = self._workflow_cycle_manager.workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield iter_finish_resp @@ -437,16 +406,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + loop_start_resp = self._workflow_cycle_manager.workflow_loop_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield loop_start_resp @@ -454,16 +418,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + loop_next_resp = self._workflow_cycle_manager.workflow_loop_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield loop_next_resp @@ -471,16 +430,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + loop_finish_resp = self._workflow_cycle_manager.workflow_loop_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield loop_finish_resp @@ -491,10 +445,8 @@ class WorkflowAppGenerateTaskPipeline: raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_success( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, outputs=event.outputs, @@ -503,12 +455,12 @@ class WorkflowAppGenerateTaskPipeline: ) # save workflow app log - self._save_workflow_app_log(session=session, workflow_run=workflow_run) + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( + workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, + workflow_execution=workflow_execution, ) session.commit() @@ -520,10 +472,8 @@ class WorkflowAppGenerateTaskPipeline: raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, outputs=event.outputs, @@ -533,10 +483,12 @@ class WorkflowAppGenerateTaskPipeline: ) # save workflow app log - self._save_workflow_app_log(session=session, workflow_run=workflow_run) + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) session.commit() @@ -548,26 +500,28 @@ class WorkflowAppGenerateTaskPipeline: raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED, - error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + error_message=event.error + if isinstance(event, QueueWorkflowFailedEvent) + else event.get_stop_reason(), conversation_id=None, trace_manager=trace_manager, exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, ) # save workflow app log - self._save_workflow_app_log(session=session, workflow_run=workflow_run) + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) session.commit() @@ -586,7 +540,7 @@ class WorkflowAppGenerateTaskPipeline: delta_text, from_variable_selector=event.from_variable_selector ) elif isinstance(event, QueueAgentLogEvent): - yield self._workflow_cycle_manager._handle_agent_log( + yield self._workflow_cycle_manager.handle_agent_log( task_id=self._application_generate_entity.task_id, event=event ) else: @@ -595,11 +549,9 @@ class WorkflowAppGenerateTaskPipeline: if tts_publisher: tts_publisher.publish(None) - def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None: - """ - Save workflow app log. - :return: - """ + def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: + workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id)) + assert workflow_run is not None invoke_from = self._application_generate_entity.invoke_from if invoke_from == InvokeFrom.SERVICE_API: created_from = WorkflowAppLogCreatedFrom.SERVICE_API diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 6d33d7372c..d4c2b1b6bd 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -1,4 +1,3 @@ -import json import time from collections.abc import Mapping, Sequence from datetime import UTC, datetime @@ -8,7 +7,7 @@ from uuid import uuid4 from sqlalchemy import func, select from sqlalchemy.orm import Session -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( QueueAgentLogEvent, QueueIterationCompletedEvent, @@ -54,9 +53,11 @@ from core.workflow.entities.node_execution_entities import ( NodeExecution, NodeExecutionStatus, ) +from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus, WorkflowType from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData +from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.workflow_entry import WorkflowEntry from models import ( @@ -67,7 +68,6 @@ from models import ( WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus, - WorkflowRunTriggeredFrom, ) @@ -77,21 +77,20 @@ class WorkflowCycleManager: *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], workflow_system_variables: dict[SystemVariableKey, Any], + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: - self._workflow_run: WorkflowRun | None = None self._application_generate_entity = application_generate_entity self._workflow_system_variables = workflow_system_variables + self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository - def _handle_workflow_run_start( + def handle_workflow_run_start( self, *, session: Session, workflow_id: str, - user_id: str, - created_by_role: CreatorUserRole, - ) -> WorkflowRun: + ) -> WorkflowExecution: workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) workflow = session.scalar(workflow_stmt) if not workflow: @@ -110,157 +109,116 @@ class WorkflowCycleManager: continue inputs[f"sys.{key.value}"] = value - triggered_from = ( - WorkflowRunTriggeredFrom.DEBUGGING - if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER - else WorkflowRunTriggeredFrom.APP_RUN - ) - # handle special values inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this - workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4()) + execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4()) + execution = WorkflowExecution.new( + id=execution_id, + workflow_id=workflow.id, + sequence_number=new_sequence_number, + type=WorkflowType(workflow.type), + workflow_version=workflow.version, + graph=workflow.graph_dict, + inputs=inputs, + started_at=datetime.now(UTC).replace(tzinfo=None), + ) - workflow_run = WorkflowRun() - workflow_run.id = workflow_run_id - workflow_run.tenant_id = workflow.tenant_id - workflow_run.app_id = workflow.app_id - workflow_run.sequence_number = new_sequence_number - workflow_run.workflow_id = workflow.id - workflow_run.type = workflow.type - workflow_run.triggered_from = triggered_from.value - workflow_run.version = workflow.version - workflow_run.graph = workflow.graph - workflow_run.inputs = json.dumps(inputs) - workflow_run.status = WorkflowRunStatus.RUNNING - workflow_run.created_by_role = created_by_role - workflow_run.created_by = user_id - workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) + self._workflow_execution_repository.save(execution) - session.add(workflow_run) + return execution - return workflow_run - - def _handle_workflow_run_success( + def handle_workflow_run_success( self, *, - session: Session, workflow_run_id: str, - start_at: float, total_tokens: int, total_steps: int, outputs: Mapping[str, Any] | None = None, conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, - ) -> WorkflowRun: - """ - Workflow run success - :param workflow_run_id: workflow run id - :param start_at: start time - :param total_tokens: total tokens - :param total_steps: total steps - :param outputs: outputs - :param conversation_id: conversation id - :return: - """ - workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) + ) -> WorkflowExecution: + workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) outputs = WorkflowEntry.handle_special_values(outputs) - workflow_run.status = WorkflowRunStatus.SUCCEEDED - workflow_run.outputs = json.dumps(outputs or {}) - workflow_run.elapsed_time = time.perf_counter() - start_at - workflow_run.total_tokens = total_tokens - workflow_run.total_steps = total_steps - workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) + workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED + workflow_execution.outputs = outputs or {} + workflow_execution.total_tokens = total_tokens + workflow_execution.total_steps = total_steps + workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) if trace_manager: trace_manager.add_trace_task( TraceTask( TraceTaskName.WORKFLOW_TRACE, - workflow_run=workflow_run, + workflow_execution=workflow_execution, conversation_id=conversation_id, user_id=trace_manager.user_id, ) ) - return workflow_run + return workflow_execution - def _handle_workflow_run_partial_success( + def handle_workflow_run_partial_success( self, *, - session: Session, workflow_run_id: str, - start_at: float, total_tokens: int, total_steps: int, outputs: Mapping[str, Any] | None = None, exceptions_count: int = 0, conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, - ) -> WorkflowRun: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) + ) -> WorkflowExecution: + execution = self._get_workflow_execution_or_raise_error(workflow_run_id) outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) - workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCEEDED.value - workflow_run.outputs = json.dumps(outputs or {}) - workflow_run.elapsed_time = time.perf_counter() - start_at - workflow_run.total_tokens = total_tokens - workflow_run.total_steps = total_steps - workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) - workflow_run.exceptions_count = exceptions_count + execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED + execution.outputs = outputs or {} + execution.total_tokens = total_tokens + execution.total_steps = total_steps + execution.finished_at = datetime.now(UTC).replace(tzinfo=None) + execution.exceptions_count = exceptions_count if trace_manager: trace_manager.add_trace_task( TraceTask( TraceTaskName.WORKFLOW_TRACE, - workflow_run=workflow_run, + workflow_execution=execution, conversation_id=conversation_id, user_id=trace_manager.user_id, ) ) - return workflow_run + return execution - def _handle_workflow_run_failed( + def handle_workflow_run_failed( self, *, - session: Session, workflow_run_id: str, - start_at: float, total_tokens: int, total_steps: int, status: WorkflowRunStatus, - error: str, + error_message: str, conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, exceptions_count: int = 0, - ) -> WorkflowRun: - """ - Workflow run failed - :param workflow_run_id: workflow run id - :param start_at: start time - :param total_tokens: total tokens - :param total_steps: total steps - :param status: status - :param error: error message - :return: - """ - workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) + ) -> WorkflowExecution: + execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - workflow_run.status = status.value - workflow_run.error = error - workflow_run.elapsed_time = time.perf_counter() - start_at - workflow_run.total_tokens = total_tokens - workflow_run.total_steps = total_steps - workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) - workflow_run.exceptions_count = exceptions_count + execution.status = WorkflowExecutionStatus(status.value) + execution.error_message = error_message + execution.total_tokens = total_tokens + execution.total_steps = total_steps + execution.finished_at = datetime.now(UTC).replace(tzinfo=None) + execution.exceptions_count = exceptions_count # Use the instance repository to find running executions for a workflow run running_domain_executions = self._workflow_node_execution_repository.get_running_executions( - workflow_run_id=workflow_run.id + workflow_run_id=execution.id ) # Update the domain models @@ -269,7 +227,7 @@ class WorkflowCycleManager: if domain_execution.node_execution_id: # Update the domain model domain_execution.status = NodeExecutionStatus.FAILED - domain_execution.error = error + domain_execution.error = error_message domain_execution.finished_at = now domain_execution.elapsed_time = (now - domain_execution.created_at).total_seconds() @@ -280,15 +238,22 @@ class WorkflowCycleManager: trace_manager.add_trace_task( TraceTask( TraceTaskName.WORKFLOW_TRACE, - workflow_run=workflow_run, + workflow_execution=execution, conversation_id=conversation_id, user_id=trace_manager.user_id, ) ) - return workflow_run + return execution + + def handle_node_execution_start( + self, + *, + workflow_execution_id: str, + event: QueueNodeStartedEvent, + ) -> NodeExecution: + workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) - def _handle_node_execution_start(self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> NodeExecution: # Create a domain model created_at = datetime.now(UTC).replace(tzinfo=None) metadata = { @@ -299,8 +264,8 @@ class WorkflowCycleManager: domain_execution = NodeExecution( id=str(uuid4()), - workflow_id=workflow_run.workflow_id, - workflow_run_id=workflow_run.id, + workflow_id=workflow_execution.workflow_id, + workflow_run_id=workflow_execution.id, predecessor_node_id=event.predecessor_node_id, index=event.node_run_index, node_execution_id=event.node_execution_id, @@ -317,7 +282,7 @@ class WorkflowCycleManager: return domain_execution - def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution: + def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution: # Get the domain model from repository domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) if not domain_execution: @@ -350,7 +315,7 @@ class WorkflowCycleManager: return domain_execution - def _handle_workflow_node_execution_failed( + def handle_workflow_node_execution_failed( self, *, event: QueueNodeFailedEvent @@ -400,15 +365,10 @@ class WorkflowCycleManager: return domain_execution - def _handle_workflow_node_execution_retried( - self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent + def handle_workflow_node_execution_retried( + self, *, workflow_execution_id: str, event: QueueNodeRetryEvent ) -> NodeExecution: - """ - Workflow node execution failed - :param workflow_run: workflow run - :param event: queue node failed event - :return: - """ + workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) created_at = event.start_at finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - created_at).total_seconds() @@ -433,8 +393,8 @@ class WorkflowCycleManager: # Create a domain model domain_execution = NodeExecution( id=str(uuid4()), - workflow_id=workflow_run.workflow_id, - workflow_run_id=workflow_run.id, + workflow_id=workflow_execution.workflow_id, + workflow_run_id=workflow_execution.id, predecessor_node_id=event.predecessor_node_id, node_execution_id=event.node_execution_id, node_id=event.node_id, @@ -456,34 +416,34 @@ class WorkflowCycleManager: return domain_execution - def _workflow_start_to_stream_response( + def workflow_start_to_stream_response( self, *, - session: Session, task_id: str, - workflow_run: WorkflowRun, + workflow_execution: WorkflowExecution, ) -> WorkflowStartStreamResponse: - _ = session return WorkflowStartStreamResponse( task_id=task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=workflow_execution.id, data=WorkflowStartStreamResponse.Data( - id=workflow_run.id, - workflow_id=workflow_run.workflow_id, - sequence_number=workflow_run.sequence_number, - inputs=dict(workflow_run.inputs_dict or {}), - created_at=int(workflow_run.created_at.timestamp()), + id=workflow_execution.id, + workflow_id=workflow_execution.workflow_id, + sequence_number=workflow_execution.sequence_number, + inputs=workflow_execution.inputs, + created_at=int(workflow_execution.started_at.timestamp()), ), ) - def _workflow_finish_to_stream_response( + def workflow_finish_to_stream_response( self, *, session: Session, task_id: str, - workflow_run: WorkflowRun, + workflow_execution: WorkflowExecution, ) -> WorkflowFinishStreamResponse: created_by = None + workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id)) + assert workflow_run is not None if workflow_run.created_by_role == CreatorUserRole.ACCOUNT: stmt = select(Account).where(Account.id == workflow_run.created_by) account = session.scalar(stmt) @@ -504,28 +464,35 @@ class WorkflowCycleManager: else: raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}") + # Handle the case where finished_at is None by using current time as default + finished_at_timestamp = ( + int(workflow_execution.finished_at.timestamp()) + if workflow_execution.finished_at + else int(datetime.now(UTC).timestamp()) + ) + return WorkflowFinishStreamResponse( task_id=task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=workflow_execution.id, data=WorkflowFinishStreamResponse.Data( - id=workflow_run.id, - workflow_id=workflow_run.workflow_id, - sequence_number=workflow_run.sequence_number, - status=workflow_run.status, - outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None, - error=workflow_run.error, - elapsed_time=workflow_run.elapsed_time, - total_tokens=workflow_run.total_tokens, - total_steps=workflow_run.total_steps, + id=workflow_execution.id, + workflow_id=workflow_execution.workflow_id, + sequence_number=workflow_execution.sequence_number, + status=workflow_execution.status, + outputs=workflow_execution.outputs, + error=workflow_execution.error_message, + elapsed_time=workflow_execution.elapsed_time, + total_tokens=workflow_execution.total_tokens, + total_steps=workflow_execution.total_steps, created_by=created_by, - created_at=int(workflow_run.created_at.timestamp()), - finished_at=int(workflow_run.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)), - exceptions_count=workflow_run.exceptions_count, + created_at=int(workflow_execution.started_at.timestamp()), + finished_at=finished_at_timestamp, + files=self.fetch_files_from_node_outputs(workflow_execution.outputs), + exceptions_count=workflow_execution.exceptions_count, ), ) - def _workflow_node_start_to_stream_response( + def workflow_node_start_to_stream_response( self, *, event: QueueNodeStartedEvent, @@ -571,7 +538,7 @@ class WorkflowCycleManager: return response - def _workflow_node_finish_to_stream_response( + def workflow_node_finish_to_stream_response( self, *, event: QueueNodeSucceededEvent @@ -608,7 +575,7 @@ class WorkflowCycleManager: execution_metadata=workflow_node_execution.metadata, created_at=int(workflow_node_execution.created_at.timestamp()), finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), + files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, @@ -618,7 +585,7 @@ class WorkflowCycleManager: ), ) - def _workflow_node_retry_to_stream_response( + def workflow_node_retry_to_stream_response( self, *, event: QueueNodeRetryEvent, @@ -651,7 +618,7 @@ class WorkflowCycleManager: execution_metadata=workflow_node_execution.metadata, created_at=int(workflow_node_execution.created_at.timestamp()), finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), + files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, @@ -662,13 +629,16 @@ class WorkflowCycleManager: ), ) - def _workflow_parallel_branch_start_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent + def workflow_parallel_branch_start_to_stream_response( + self, + *, + task_id: str, + workflow_execution_id: str, + event: QueueParallelBranchRunStartedEvent, ) -> ParallelBranchStartStreamResponse: - _ = session return ParallelBranchStartStreamResponse( task_id=task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=workflow_execution_id, data=ParallelBranchStartStreamResponse.Data( parallel_id=event.parallel_id, parallel_branch_id=event.parallel_start_node_id, @@ -680,18 +650,16 @@ class WorkflowCycleManager: ), ) - def _workflow_parallel_branch_finished_to_stream_response( + def workflow_parallel_branch_finished_to_stream_response( self, *, - session: Session, task_id: str, - workflow_run: WorkflowRun, + workflow_execution_id: str, event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, ) -> ParallelBranchFinishedStreamResponse: - _ = session return ParallelBranchFinishedStreamResponse( task_id=task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=workflow_execution_id, data=ParallelBranchFinishedStreamResponse.Data( parallel_id=event.parallel_id, parallel_branch_id=event.parallel_start_node_id, @@ -705,13 +673,16 @@ class WorkflowCycleManager: ), ) - def _workflow_iteration_start_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent + def workflow_iteration_start_to_stream_response( + self, + *, + task_id: str, + workflow_execution_id: str, + event: QueueIterationStartEvent, ) -> IterationNodeStartStreamResponse: - _ = session return IterationNodeStartStreamResponse( task_id=task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=workflow_execution_id, data=IterationNodeStartStreamResponse.Data( id=event.node_id, node_id=event.node_id, @@ -726,13 +697,16 @@ class WorkflowCycleManager: ), ) - def _workflow_iteration_next_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent + def workflow_iteration_next_to_stream_response( + self, + *, + task_id: str, + workflow_execution_id: str, + event: QueueIterationNextEvent, ) -> IterationNodeNextStreamResponse: - _ = session return IterationNodeNextStreamResponse( task_id=task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=workflow_execution_id, data=IterationNodeNextStreamResponse.Data( id=event.node_id, node_id=event.node_id, @@ -749,13 +723,16 @@ class WorkflowCycleManager: ), ) - def _workflow_iteration_completed_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent + def workflow_iteration_completed_to_stream_response( + self, + *, + task_id: str, + workflow_execution_id: str, + event: QueueIterationCompletedEvent, ) -> IterationNodeCompletedStreamResponse: - _ = session return IterationNodeCompletedStreamResponse( task_id=task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=workflow_execution_id, data=IterationNodeCompletedStreamResponse.Data( id=event.node_id, node_id=event.node_id, @@ -779,13 +756,12 @@ class WorkflowCycleManager: ), ) - def _workflow_loop_start_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent + def workflow_loop_start_to_stream_response( + self, *, task_id: str, workflow_execution_id: str, event: QueueLoopStartEvent ) -> LoopNodeStartStreamResponse: - _ = session return LoopNodeStartStreamResponse( task_id=task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=workflow_execution_id, data=LoopNodeStartStreamResponse.Data( id=event.node_id, node_id=event.node_id, @@ -800,13 +776,16 @@ class WorkflowCycleManager: ), ) - def _workflow_loop_next_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent + def workflow_loop_next_to_stream_response( + self, + *, + task_id: str, + workflow_execution_id: str, + event: QueueLoopNextEvent, ) -> LoopNodeNextStreamResponse: - _ = session return LoopNodeNextStreamResponse( task_id=task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=workflow_execution_id, data=LoopNodeNextStreamResponse.Data( id=event.node_id, node_id=event.node_id, @@ -823,13 +802,16 @@ class WorkflowCycleManager: ), ) - def _workflow_loop_completed_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent + def workflow_loop_completed_to_stream_response( + self, + *, + task_id: str, + workflow_execution_id: str, + event: QueueLoopCompletedEvent, ) -> LoopNodeCompletedStreamResponse: - _ = session return LoopNodeCompletedStreamResponse( task_id=task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=workflow_execution_id, data=LoopNodeCompletedStreamResponse.Data( id=event.node_id, node_id=event.node_id, @@ -853,7 +835,7 @@ class WorkflowCycleManager: ), ) - def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]: + def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]: """ Fetch files from node outputs :param outputs_dict: node outputs dict @@ -910,20 +892,13 @@ class WorkflowCycleManager: return None - def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: - if self._workflow_run and self._workflow_run.id == workflow_run_id: - cached_workflow_run = self._workflow_run - cached_workflow_run = session.merge(cached_workflow_run) - return cached_workflow_run - stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) - workflow_run = session.scalar(stmt) - if not workflow_run: - raise WorkflowRunNotFoundError(workflow_run_id) - self._workflow_run = workflow_run + def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution: + execution = self._workflow_execution_repository.get(id) + if not execution: + raise WorkflowRunNotFoundError(id) + return execution - return workflow_run - - def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: + def handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: """ Handle agent log :param task_id: task id diff --git a/api/models/workflow.py b/api/models/workflow.py index fc6de8692c..64b0e16577 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -425,14 +425,14 @@ class WorkflowRun(Base): status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") error: Mapped[Optional[str]] = mapped_column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0")) + elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) - total_steps = db.Column(db.Integer, server_default=db.text("0")) + total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - finished_at = db.Column(db.DateTime) - exceptions_count = db.Column(db.Integer, server_default=db.text("0")) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) @property def created_by_account(self): @@ -447,7 +447,7 @@ class WorkflowRun(Base): return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None @property - def graph_dict(self): + def graph_dict(self) -> Mapping[str, Any]: return json.loads(self.graph) if self.graph else {} @property @@ -752,12 +752,12 @@ class WorkflowAppLog(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) - workflow_id = db.Column(StringUUID, nullable=False) + workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) - created_from = db.Column(db.String(255), nullable=False) - created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_from: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def workflow_run(self): @@ -782,9 +782,11 @@ class ConversationVariable(Base): id: Mapped[str] = mapped_column(StringUUID, primary_key=True) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) - data = mapped_column(db.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True) - updated_at = mapped_column( + data: Mapped[str] = mapped_column(db.Text, nullable=False) + created_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True + ) + updated_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) @@ -832,14 +834,14 @@ class WorkflowDraftVariable(Base): # id is the unique identifier of a draft variable. id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - created_at = mapped_column( + created_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, default=_naive_utc_datetime, server_default=func.current_timestamp(), ) - updated_at = mapped_column( + updated_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, default=_naive_utc_datetime, diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py index 94b9d3e2c6..9c955fc086 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py @@ -1,45 +1,73 @@ import json -import time from datetime import UTC, datetime -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest from sqlalchemy.orm import Session +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.queue_entities import ( QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, ) +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus +from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus, WorkflowType from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType +from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.workflow_cycle_manager import WorkflowCycleManager from models.enums import CreatorUserRole +from models.model import AppMode from models.workflow import ( Workflow, - WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus, ) @pytest.fixture -def mock_app_generate_entity(): - entity = MagicMock(spec=AdvancedChatAppGenerateEntity) - entity.inputs = {"query": "test query"} - entity.invoke_from = InvokeFrom.WEB_APP - # Create app_config as a separate mock - app_config = MagicMock() - app_config.tenant_id = "test-tenant-id" - app_config.app_id = "test-app-id" - entity.app_config = app_config +def real_app_generate_entity(): + additional_features = AppAdditionalFeatures( + file_upload=None, + opening_statement=None, + suggested_questions=[], + suggested_questions_after_answer=False, + show_retrieve_source=False, + more_like_this=False, + speech_to_text=False, + text_to_speech=None, + trace_config=None, + ) + + app_config = WorkflowUIBasedAppConfig( + tenant_id="test-tenant-id", + app_id="test-app-id", + app_mode=AppMode.WORKFLOW, + additional_features=additional_features, + workflow_id="test-workflow-id", + ) + + entity = AdvancedChatAppGenerateEntity( + task_id="test-task-id", + app_config=app_config, + inputs={"query": "test query"}, + files=[], + user_id="test-user-id", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + query="test query", + conversation_id="test-conversation-id", + ) + return entity @pytest.fixture -def mock_workflow_system_variables(): +def real_workflow_system_variables(): return { SystemVariableKey.QUERY: "test query", SystemVariableKey.CONVERSATION_ID: "test-conversation-id", @@ -59,10 +87,23 @@ def mock_node_execution_repository(): @pytest.fixture -def workflow_cycle_manager(mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository): +def mock_workflow_execution_repository(): + repo = MagicMock(spec=WorkflowExecutionRepository) + repo.get.return_value = None + return repo + + +@pytest.fixture +def workflow_cycle_manager( + real_app_generate_entity, + real_workflow_system_variables, + mock_workflow_execution_repository, + mock_node_execution_repository, +): return WorkflowCycleManager( - application_generate_entity=mock_app_generate_entity, - workflow_system_variables=mock_workflow_system_variables, + application_generate_entity=real_app_generate_entity, + workflow_system_variables=real_workflow_system_variables, + workflow_execution_repository=mock_workflow_execution_repository, workflow_node_execution_repository=mock_node_execution_repository, ) @@ -74,121 +115,173 @@ def mock_session(): @pytest.fixture -def mock_workflow(): - workflow = MagicMock(spec=Workflow) +def real_workflow(): + workflow = Workflow() workflow.id = "test-workflow-id" workflow.tenant_id = "test-tenant-id" workflow.app_id = "test-app-id" workflow.type = "chat" workflow.version = "1.0" - workflow.graph = json.dumps({"nodes": [], "edges": []}) + + graph_data = {"nodes": [], "edges": []} + workflow.graph = json.dumps(graph_data) + workflow.features = json.dumps({"file_upload": {"enabled": False}}) + workflow.created_by = "test-user-id" + workflow.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow._environment_variables = "{}" + workflow._conversation_variables = "{}" + return workflow @pytest.fixture -def mock_workflow_run(): - workflow_run = MagicMock(spec=WorkflowRun) +def real_workflow_run(): + workflow_run = WorkflowRun() workflow_run.id = "test-workflow-run-id" workflow_run.tenant_id = "test-tenant-id" workflow_run.app_id = "test-app-id" workflow_run.workflow_id = "test-workflow-id" + workflow_run.sequence_number = 1 + workflow_run.type = "chat" + workflow_run.triggered_from = "app-run" + workflow_run.version = "1.0" + workflow_run.graph = json.dumps({"nodes": [], "edges": []}) + workflow_run.inputs = json.dumps({"query": "test query"}) workflow_run.status = WorkflowRunStatus.RUNNING + workflow_run.outputs = json.dumps({"answer": "test answer"}) workflow_run.created_by_role = CreatorUserRole.ACCOUNT workflow_run.created_by = "test-user-id" workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) - workflow_run.inputs_dict = {"query": "test query"} - workflow_run.outputs_dict = {"answer": "test answer"} + return workflow_run def test_init( - workflow_cycle_manager, mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository + workflow_cycle_manager, + real_app_generate_entity, + real_workflow_system_variables, + mock_workflow_execution_repository, + mock_node_execution_repository, ): """Test initialization of WorkflowCycleManager""" - assert workflow_cycle_manager._workflow_run is None - assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity - assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables + assert workflow_cycle_manager._application_generate_entity == real_app_generate_entity + assert workflow_cycle_manager._workflow_system_variables == real_workflow_system_variables + assert workflow_cycle_manager._workflow_execution_repository == mock_workflow_execution_repository assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository -def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_workflow): - """Test _handle_workflow_run_start method""" +def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, real_workflow): + """Test handle_workflow_run_start method""" # Mock session.scalar to return the workflow and max sequence - mock_session.scalar.side_effect = [mock_workflow, 5] + mock_session.scalar.side_effect = [real_workflow, 5] # Call the method - workflow_run = workflow_cycle_manager._handle_workflow_run_start( + workflow_execution = workflow_cycle_manager.handle_workflow_run_start( session=mock_session, workflow_id="test-workflow-id", - user_id="test-user-id", - created_by_role=CreatorUserRole.ACCOUNT, ) # Verify the result - assert workflow_run.tenant_id == mock_workflow.tenant_id - assert workflow_run.app_id == mock_workflow.app_id - assert workflow_run.workflow_id == mock_workflow.id - assert workflow_run.sequence_number == 6 # max_sequence + 1 - assert workflow_run.status == WorkflowRunStatus.RUNNING - assert workflow_run.created_by_role == CreatorUserRole.ACCOUNT - assert workflow_run.created_by == "test-user-id" + assert workflow_execution.workflow_id == real_workflow.id + assert workflow_execution.sequence_number == 6 # max_sequence + 1 - # Verify session.add was called - mock_session.add.assert_called_once_with(workflow_run) + # Verify the workflow_execution_repository.save was called + workflow_cycle_manager._workflow_execution_repository.save.assert_called_once_with(workflow_execution) -def test_handle_workflow_run_success(workflow_cycle_manager, mock_session, mock_workflow_run): - """Test _handle_workflow_run_success method""" - # Mock _get_workflow_run to return the mock_workflow_run - with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run): - # Call the method - result = workflow_cycle_manager._handle_workflow_run_success( - session=mock_session, - workflow_run_id="test-workflow-run-id", - start_at=time.perf_counter() - 10, # 10 seconds ago - total_tokens=100, - total_steps=5, - outputs={"answer": "test answer"}, - ) +def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execution_repository): + """Test handle_workflow_run_success method""" + # Create a real WorkflowExecution - # Verify the result - assert result == mock_workflow_run - assert result.status == WorkflowRunStatus.SUCCEEDED - assert result.outputs == json.dumps({"answer": "test answer"}) - assert result.total_tokens == 100 - assert result.total_steps == 5 - assert result.finished_at is not None + workflow_execution = WorkflowExecution( + id="test-workflow-run-id", + workflow_id="test-workflow-id", + workflow_version="1.0", + sequence_number=1, + type=WorkflowType.CHAT, + graph={"nodes": [], "edges": []}, + inputs={"query": "test query"}, + started_at=datetime.now(UTC).replace(tzinfo=None), + ) + + # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution + + # Call the method + result = workflow_cycle_manager.handle_workflow_run_success( + workflow_run_id="test-workflow-run-id", + total_tokens=100, + total_steps=5, + outputs={"answer": "test answer"}, + ) + + # Verify the result + assert result == workflow_execution + assert result.status == WorkflowExecutionStatus.SUCCEEDED + assert result.outputs == {"answer": "test answer"} + assert result.total_tokens == 100 + assert result.total_steps == 5 + assert result.finished_at is not None -def test_handle_workflow_run_failed(workflow_cycle_manager, mock_session, mock_workflow_run): - """Test _handle_workflow_run_failed method""" - # Mock _get_workflow_run to return the mock_workflow_run - with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run): - # Mock get_running_executions to return an empty list - workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = [] +def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execution_repository): + """Test handle_workflow_run_failed method""" + # Create a real WorkflowExecution - # Call the method - result = workflow_cycle_manager._handle_workflow_run_failed( - session=mock_session, - workflow_run_id="test-workflow-run-id", - start_at=time.perf_counter() - 10, # 10 seconds ago - total_tokens=50, - total_steps=3, - status=WorkflowRunStatus.FAILED, - error="Test error message", - ) + workflow_execution = WorkflowExecution( + id="test-workflow-run-id", + workflow_id="test-workflow-id", + workflow_version="1.0", + sequence_number=1, + type=WorkflowType.CHAT, + graph={"nodes": [], "edges": []}, + inputs={"query": "test query"}, + started_at=datetime.now(UTC).replace(tzinfo=None), + ) - # Verify the result - assert result == mock_workflow_run - assert result.status == WorkflowRunStatus.FAILED.value - assert result.error == "Test error message" - assert result.total_tokens == 50 - assert result.total_steps == 3 - assert result.finished_at is not None + # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution + + # Mock get_running_executions to return an empty list + workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = [] + + # Call the method + result = workflow_cycle_manager.handle_workflow_run_failed( + workflow_run_id="test-workflow-run-id", + total_tokens=50, + total_steps=3, + status=WorkflowRunStatus.FAILED, + error_message="Test error message", + ) + + # Verify the result + assert result == workflow_execution + assert result.status == WorkflowExecutionStatus(WorkflowRunStatus.FAILED.value) + assert result.error_message == "Test error message" + assert result.total_tokens == 50 + assert result.total_steps == 3 + assert result.finished_at is not None -def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run): - """Test _handle_node_execution_start method""" +def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execution_repository): + """Test handle_node_execution_start method""" + # Create a real WorkflowExecution + + workflow_execution = WorkflowExecution( + id="test-workflow-execution-id", + workflow_id="test-workflow-id", + workflow_version="1.0", + sequence_number=1, + type=WorkflowType.CHAT, + graph={"nodes": [], "edges": []}, + inputs={"query": "test query"}, + started_at=datetime.now(UTC).replace(tzinfo=None), + ) + + # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution + # Create a mock event event = MagicMock(spec=QueueNodeStartedEvent) event.node_execution_id = "test-node-execution-id" @@ -207,129 +300,171 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run): event.in_loop_id = "test-loop-id" # Call the method - result = workflow_cycle_manager._handle_node_execution_start( - workflow_run=mock_workflow_run, + result = workflow_cycle_manager.handle_node_execution_start( + workflow_execution_id=workflow_execution.id, event=event, ) # Verify the result - # NodeExecution doesn't have tenant_id attribute, it's handled at repository level - # assert result.tenant_id == mock_workflow_run.tenant_id - # assert result.app_id == mock_workflow_run.app_id - assert result.workflow_id == mock_workflow_run.workflow_id - assert result.workflow_run_id == mock_workflow_run.id + assert result.workflow_id == workflow_execution.workflow_id + assert result.workflow_run_id == workflow_execution.id assert result.node_execution_id == event.node_execution_id assert result.node_id == event.node_id assert result.node_type == event.node_type assert result.title == event.node_data.title - assert result.status == WorkflowNodeExecutionStatus.RUNNING.value - # NodeExecution doesn't have created_by_role and created_by attributes, they're handled at repository level - # assert result.created_by_role == mock_workflow_run.created_by_role - # assert result.created_by == mock_workflow_run.created_by + assert result.status == NodeExecutionStatus.RUNNING # Verify save was called workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result) -def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run): - """Test _get_workflow_run method""" - # Mock session.scalar to return the workflow run - mock_session.scalar.return_value = mock_workflow_run +def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_workflow_execution_repository): + """Test _get_workflow_execution_or_raise_error method""" + # Create a real WorkflowExecution - # Call the method - result = workflow_cycle_manager._get_workflow_run( - session=mock_session, - workflow_run_id="test-workflow-run-id", + workflow_execution = WorkflowExecution( + id="test-workflow-run-id", + workflow_id="test-workflow-id", + workflow_version="1.0", + sequence_number=1, + type=WorkflowType.CHAT, + graph={"nodes": [], "edges": []}, + inputs={"query": "test query"}, + started_at=datetime.now(UTC).replace(tzinfo=None), ) + # Mock the repository get method to return the real execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution + + # Call the method + result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id") + # Verify the result - assert result == mock_workflow_run - assert workflow_cycle_manager._workflow_run == mock_workflow_run + assert result == workflow_execution + + # Test error case + workflow_cycle_manager._workflow_execution_repository.get.return_value = None + + # Expect an error when execution is not found + with pytest.raises(ValueError): + workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id") def test_handle_workflow_node_execution_success(workflow_cycle_manager): - """Test _handle_workflow_node_execution_success method""" + """Test handle_workflow_node_execution_success method""" # Create a mock event event = MagicMock(spec=QueueNodeSucceededEvent) event.node_execution_id = "test-node-execution-id" event.inputs = {"input": "test input"} event.process_data = {"process": "test process"} event.outputs = {"output": "test output"} - event.execution_metadata = {"metadata": "test metadata"} + event.execution_metadata = {NodeRunMetadataKey.TOTAL_TOKENS: 100} event.start_at = datetime.now(UTC).replace(tzinfo=None) - # Create a mock node execution - node_execution = MagicMock() - node_execution.node_execution_id = "test-node-execution-id" + # Create a real node execution + + node_execution = NodeExecution( + id="test-node-execution-record-id", + node_execution_id="test-node-execution-id", + workflow_id="test-workflow-id", + workflow_run_id="test-workflow-run-id", + index=1, + node_id="test-node-id", + node_type=NodeType.LLM, + title="Test Node", + created_at=datetime.now(UTC).replace(tzinfo=None), + ) # Mock the repository to return the node execution workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution # Call the method - result = workflow_cycle_manager._handle_workflow_node_execution_success( + result = workflow_cycle_manager.handle_workflow_node_execution_success( event=event, ) # Verify the result assert result == node_execution - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value + assert result.status == NodeExecutionStatus.SUCCEEDED # Verify save was called workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution) -def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run): - """Test _handle_workflow_run_partial_success method""" - # Mock _get_workflow_run to return the mock_workflow_run - with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run): - # Call the method - result = workflow_cycle_manager._handle_workflow_run_partial_success( - session=mock_session, - workflow_run_id="test-workflow-run-id", - start_at=time.perf_counter() - 10, # 10 seconds ago - total_tokens=75, - total_steps=4, - outputs={"partial_answer": "test partial answer"}, - exceptions_count=2, - ) +def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workflow_execution_repository): + """Test handle_workflow_run_partial_success method""" + # Create a real WorkflowExecution - # Verify the result - assert result == mock_workflow_run - assert result.status == WorkflowRunStatus.PARTIAL_SUCCEEDED.value - assert result.outputs == json.dumps({"partial_answer": "test partial answer"}) - assert result.total_tokens == 75 - assert result.total_steps == 4 - assert result.exceptions_count == 2 - assert result.finished_at is not None + workflow_execution = WorkflowExecution( + id="test-workflow-run-id", + workflow_id="test-workflow-id", + workflow_version="1.0", + sequence_number=1, + type=WorkflowType.CHAT, + graph={"nodes": [], "edges": []}, + inputs={"query": "test query"}, + started_at=datetime.now(UTC).replace(tzinfo=None), + ) + + # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution + + # Call the method + result = workflow_cycle_manager.handle_workflow_run_partial_success( + workflow_run_id="test-workflow-run-id", + total_tokens=75, + total_steps=4, + outputs={"partial_answer": "test partial answer"}, + exceptions_count=2, + ) + + # Verify the result + assert result == workflow_execution + assert result.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED + assert result.outputs == {"partial_answer": "test partial answer"} + assert result.total_tokens == 75 + assert result.total_steps == 4 + assert result.exceptions_count == 2 + assert result.finished_at is not None def test_handle_workflow_node_execution_failed(workflow_cycle_manager): - """Test _handle_workflow_node_execution_failed method""" + """Test handle_workflow_node_execution_failed method""" # Create a mock event event = MagicMock(spec=QueueNodeFailedEvent) event.node_execution_id = "test-node-execution-id" event.inputs = {"input": "test input"} event.process_data = {"process": "test process"} event.outputs = {"output": "test output"} - event.execution_metadata = {"metadata": "test metadata"} + event.execution_metadata = {NodeRunMetadataKey.TOTAL_TOKENS: 100} event.start_at = datetime.now(UTC).replace(tzinfo=None) event.error = "Test error message" - # Create a mock node execution - node_execution = MagicMock() - node_execution.node_execution_id = "test-node-execution-id" + # Create a real node execution + + node_execution = NodeExecution( + id="test-node-execution-record-id", + node_execution_id="test-node-execution-id", + workflow_id="test-workflow-id", + workflow_run_id="test-workflow-run-id", + index=1, + node_id="test-node-id", + node_type=NodeType.LLM, + title="Test Node", + created_at=datetime.now(UTC).replace(tzinfo=None), + ) # Mock the repository to return the node execution workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution # Call the method - result = workflow_cycle_manager._handle_workflow_node_execution_failed( + result = workflow_cycle_manager.handle_workflow_node_execution_failed( event=event, ) # Verify the result assert result == node_execution - assert result.status == WorkflowNodeExecutionStatus.FAILED.value + assert result.status == NodeExecutionStatus.FAILED assert result.error == "Test error message" # Verify save was called