fix(workflow): fix generate issues in workflow

This commit is contained in:
takatost 2024-08-15 20:45:23 +08:00
parent 1da5862a96
commit 702df31db7
8 changed files with 94 additions and 31 deletions

View File

@ -384,28 +384,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
elif isinstance(event, QueueStopEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
error='Workflow stopped.',
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
if workflow_run and graph_runtime_state:
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
error=event.get_stop_reason(),
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
# Save message
self._save_message(graph_runtime_state=graph_runtime_state)
@ -471,7 +466,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self, graph_runtime_state: GraphRuntimeState) -> None:
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
"""
Save message.
:return:
@ -483,7 +478,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
if graph_runtime_state.llm_usage:
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
@ -511,7 +506,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata
extras['metadata'] = self._task_state.metadata.copy()
if 'annotation_reply' in extras['metadata']:
del extras['metadata']['annotation_reply']
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,

View File

@ -50,6 +50,7 @@ from models.workflow import (
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowRun,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)
@ -304,7 +305,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
@ -324,6 +325,31 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run

View File

@ -1,4 +1,5 @@
from typing import Any, Mapping, Optional, cast
from collections.abc import Mapping
from typing import Any, Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
@ -198,7 +199,7 @@ class WorkflowBasedAppRunner(AppRunner):
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
start_at=event.route_node_state.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id
)
)

View File

@ -324,6 +324,19 @@ class QueueStopEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy
def get_stop_reason(self) -> str:
"""
To stop reason
"""
reason_mapping = {
QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.',
QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.',
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
}
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
class QueueMessage(BaseModel):
"""

View File

@ -8,6 +8,7 @@ from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,

View File

@ -178,6 +178,23 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
).all()
for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()

View File

@ -43,6 +43,8 @@ class RouteNodeState(BaseModel):
paused_by: Optional[str] = None
"""paused by"""
index: int = 1
def set_finished(self, run_result: NodeRunResult) -> None:
"""
Node finished

View File

@ -39,7 +39,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@ -193,13 +193,20 @@ class GraphEngine:
try:
# run node
yield from self._run_node(
generator = self._run_node(
node_instance=node_instance,
route_node_state=route_node_state,
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
)
for item in generator:
if isinstance(item, NodeRunStartedEvent):
self.graph_runtime_state.node_run_steps += 1
item.route_node_state.index = self.graph_runtime_state.node_run_steps
yield item
self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state
# append route
@ -394,8 +401,6 @@ class GraphEngine:
db.session.close()
self.graph_runtime_state.node_run_steps += 1
try:
# run node
generator = node_instance.run()