fix(workflow): fix generate issues in workflow

This commit is contained in:
takatost 2024-08-15 20:45:23 +08:00
parent 1da5862a96
commit 702df31db7
8 changed files with 94 additions and 31 deletions

View File

@ -384,28 +384,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._error_to_stream_response(self._handle_error(err_event, self._message)) yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break break
elif isinstance(event, QueueStopEvent): elif isinstance(event, QueueStopEvent):
if not workflow_run: if workflow_run and graph_runtime_state:
raise Exception('Workflow run not initialized.') workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
if not graph_runtime_state: start_at=graph_runtime_state.start_at,
raise Exception('Graph runtime state not initialized.') total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
workflow_run = self._handle_workflow_run_failed( status=WorkflowRunStatus.STOPPED,
workflow_run=workflow_run, error=event.get_stop_reason(),
start_at=graph_runtime_state.start_at, conversation_id=self._conversation.id,
total_tokens=graph_runtime_state.total_tokens, trace_manager=trace_manager,
total_steps=graph_runtime_state.node_run_steps, )
status=WorkflowRunStatus.STOPPED,
error='Workflow stopped.',
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
# Save message # Save message
self._save_message(graph_runtime_state=graph_runtime_state) self._save_message(graph_runtime_state=graph_runtime_state)
@ -471,7 +466,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
def _save_message(self, graph_runtime_state: GraphRuntimeState) -> None: def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
""" """
Save message. Save message.
:return: :return:
@ -483,7 +478,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None if self._task_state.metadata else None
if graph_runtime_state.llm_usage: if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage usage = graph_runtime_state.llm_usage
self._message.message_tokens = usage.prompt_tokens self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price self._message.message_unit_price = usage.prompt_unit_price
@ -511,7 +506,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
""" """
extras = {} extras = {}
if self._task_state.metadata: if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata extras['metadata'] = self._task_state.metadata.copy()
if 'annotation_reply' in extras['metadata']:
del extras['metadata']['annotation_reply']
return MessageEndStreamResponse( return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,

View File

@ -50,6 +50,7 @@ from models.workflow import (
WorkflowAppLog, WorkflowAppLog,
WorkflowAppLogCreatedFrom, WorkflowAppLogCreatedFrom,
WorkflowRun, WorkflowRun,
WorkflowRunStatus,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -304,7 +305,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_run=workflow_run, workflow_run=workflow_run,
event=event event=event
) )
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception('Workflow run not initialized.')
@ -324,6 +325,31 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log # save workflow app log
self._save_workflow_app_log(workflow_run) self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response( yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run workflow_run=workflow_run

View File

@ -1,4 +1,5 @@
from typing import Any, Mapping, Optional, cast from collections.abc import Mapping
from typing import Any, Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner from core.app.apps.base_app_runner import AppRunner
@ -198,7 +199,7 @@ class WorkflowBasedAppRunner(AppRunner):
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
start_at=event.route_node_state.start_at, start_at=event.route_node_state.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id predecessor_node_id=event.predecessor_node_id
) )
) )

View File

@ -324,6 +324,19 @@ class QueueStopEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.STOP event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy stopped_by: StopBy
def get_stop_reason(self) -> str:
"""
To stop reason
"""
reason_mapping = {
QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.',
QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.',
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
}
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
class QueueMessage(BaseModel): class QueueMessage(BaseModel):
""" """

View File

@ -8,6 +8,7 @@ from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity, AgentChatAppGenerateEntity,
ChatAppGenerateEntity, ChatAppGenerateEntity,
CompletionAppGenerateEntity, CompletionAppGenerateEntity,
InvokeFrom,
) )
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent, QueueAnnotationReplyEvent,

View File

@ -178,6 +178,23 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit() db.session.commit()
running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
).all()
for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
db.session.commit()
db.session.refresh(workflow_run) db.session.refresh(workflow_run)
db.session.close() db.session.close()

View File

@ -43,6 +43,8 @@ class RouteNodeState(BaseModel):
paused_by: Optional[str] = None paused_by: Optional[str] = None
"""paused by""" """paused by"""
index: int = 1
def set_finished(self, run_result: NodeRunResult) -> None: def set_finished(self, run_result: NodeRunResult) -> None:
""" """
Node finished Node finished

View File

@ -39,7 +39,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes from core.workflow.nodes.node_mapping import node_classes
from extensions.ext_database import db from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@ -193,13 +193,20 @@ class GraphEngine:
try: try:
# run node # run node
yield from self._run_node( generator = self._run_node(
node_instance=node_instance, node_instance=node_instance,
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=in_parallel_id, parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id parallel_start_node_id=parallel_start_node_id
) )
for item in generator:
if isinstance(item, NodeRunStartedEvent):
self.graph_runtime_state.node_run_steps += 1
item.route_node_state.index = self.graph_runtime_state.node_run_steps
yield item
self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state
# append route # append route
@ -394,8 +401,6 @@ class GraphEngine:
db.session.close() db.session.close()
self.graph_runtime_state.node_run_steps += 1
try: try:
# run node # run node
generator = node_instance.run() generator = node_instance.run()