From 702df31db76556fff0a00f34657e7e31f95d6a2c Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 15 Aug 2024 20:45:23 +0800 Subject: [PATCH] fix(workflow): fix generate issues in workflow --- .../advanced_chat/generate_task_pipeline.py | 46 +++++++++---------- .../apps/workflow/generate_task_pipeline.py | 28 ++++++++++- api/core/app/apps/workflow_app_runner.py | 5 +- api/core/app/entities/queue_entities.py | 13 ++++++ .../app/task_pipeline/message_cycle_manage.py | 1 + .../task_pipeline/workflow_cycle_manage.py | 17 +++++++ .../entities/runtime_route_state.py | 2 + .../workflow/graph_engine/graph_engine.py | 13 ++++-- 8 files changed, 94 insertions(+), 31 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 8c7860f8b4..bf6e576bb2 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -384,28 +384,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc yield self._error_to_stream_response(self._handle_error(err_event, self._message)) break elif isinstance(event, QueueStopEvent): - if not workflow_run: - raise Exception('Workflow run not initialized.') - - if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') - - workflow_run = self._handle_workflow_run_failed( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.STOPPED, - error='Workflow stopped.', - conversation_id=self._conversation.id, - trace_manager=trace_manager, - ) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run - ) + if workflow_run and graph_runtime_state: + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.STOPPED, + error=event.get_stop_reason(), + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) + # Save message self._save_message(graph_runtime_state=graph_runtime_state) @@ -471,7 +466,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, graph_runtime_state: GraphRuntimeState) -> None: + def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: """ Save message. :return: @@ -483,7 +478,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ if self._task_state.metadata else None - if graph_runtime_state.llm_usage: + if graph_runtime_state and graph_runtime_state.llm_usage: usage = graph_runtime_state.llm_usage self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price @@ -511,7 +506,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ extras = {} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras['metadata'] = self._task_state.metadata.copy() + + if 'annotation_reply' in extras['metadata']: + del extras['metadata']['annotation_reply'] return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 08a335ec36..d19d3d3ed0 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -50,6 +50,7 @@ from models.workflow import ( WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowRun, + WorkflowRunStatus, ) logger = logging.getLogger(__name__) @@ -304,7 +305,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa workflow_run=workflow_run, event=event ) - elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): + elif isinstance(event, QueueWorkflowSucceededEvent): if not workflow_run: raise Exception('Workflow run not initialized.') @@ -324,6 +325,31 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa # save workflow app log self._save_workflow_app_log(workflow_run) + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) + elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED, + error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + conversation_id=None, + trace_manager=trace_manager, + ) + + # save workflow app log + self._save_workflow_app_log(workflow_run) + yield self._workflow_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 42a2f60582..9355b58e2e 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping, Optional, cast +from collections.abc import Mapping +from typing import Any, Optional, cast from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner @@ -198,7 +199,7 @@ class WorkflowBasedAppRunner(AppRunner): parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, start_at=event.route_node_state.start_at, - node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + node_run_index=event.route_node_state.index, predecessor_node_id=event.predecessor_node_id ) ) diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index c4882ff669..4aa806addc 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -324,6 +324,19 @@ class QueueStopEvent(AppQueueEvent): event: QueueEvent = QueueEvent.STOP stopped_by: StopBy + def get_stop_reason(self) -> str: + """ + To stop reason + """ + reason_mapping = { + QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.', + QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.', + QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.', + QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.' + } + + return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.') + class QueueMessage(BaseModel): """ diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 8ff50dd174..88f57d9f3f 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -8,6 +8,7 @@ from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, ChatAppGenerateEntity, CompletionAppGenerateEntity, + InvokeFrom, ) from core.app.entities.queue_entities import ( QueueAnnotationReplyEvent, diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 15a9833a66..ca332348ee 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -178,6 +178,23 @@ class WorkflowCycleManage: workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() + + running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value + ).all() + + for workflow_node_execution in running_workflow_node_executions: + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds() + db.session.commit() + db.session.refresh(workflow_run) db.session.close() diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index 90c918e370..b5d6e4c09d 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -43,6 +43,8 @@ class RouteNodeState(BaseModel): paused_by: Optional[str] = None """paused by""" + index: int = 1 + def set_finished(self, run_result: NodeRunResult) -> None: """ Node finished diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index b9e3e78e5c..cf9e615b32 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -39,7 +39,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor -from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import node_classes from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionStatus, WorkflowType @@ -193,13 +193,20 @@ class GraphEngine: try: # run node - yield from self._run_node( + generator = self._run_node( node_instance=node_instance, route_node_state=route_node_state, parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id ) + for item in generator: + if isinstance(item, NodeRunStartedEvent): + self.graph_runtime_state.node_run_steps += 1 + item.route_node_state.index = self.graph_runtime_state.node_run_steps + + yield item + self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state # append route @@ -394,8 +401,6 @@ class GraphEngine: db.session.close() - self.graph_runtime_state.node_run_steps += 1 - try: # run node generator = node_instance.run()