Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
Synced 2025-08-15 23:05:52 +08:00

fix(workflow): fix generate issues in workflow

Parent: 1da5862a96
Commit: 702df31db7
@@ -384,28 +384,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
                 yield self._error_to_stream_response(self._handle_error(err_event, self._message))
                 break
             elif isinstance(event, QueueStopEvent):
-                if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
-
-                if not graph_runtime_state:
-                    raise Exception('Graph runtime state not initialized.')
-
-                workflow_run = self._handle_workflow_run_failed(
-                    workflow_run=workflow_run,
-                    start_at=graph_runtime_state.start_at,
-                    total_tokens=graph_runtime_state.total_tokens,
-                    total_steps=graph_runtime_state.node_run_steps,
-                    status=WorkflowRunStatus.STOPPED,
-                    error='Workflow stopped.',
-                    conversation_id=self._conversation.id,
-                    trace_manager=trace_manager,
-                )
-
-                yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run
-                )
+                if workflow_run and graph_runtime_state:
+                    workflow_run = self._handle_workflow_run_failed(
+                        workflow_run=workflow_run,
+                        start_at=graph_runtime_state.start_at,
+                        total_tokens=graph_runtime_state.total_tokens,
+                        total_steps=graph_runtime_state.node_run_steps,
+                        status=WorkflowRunStatus.STOPPED,
+                        error=event.get_stop_reason(),
+                        conversation_id=self._conversation.id,
+                        trace_manager=trace_manager,
+                    )
+
+                    yield self._workflow_finish_to_stream_response(
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run
+                    )
 
                 # Save message
                 self._save_message(graph_runtime_state=graph_runtime_state)
 
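Note: the stop branch now guards instead of raising, so a QueueStopEvent that arrives before the workflow run is initialized no longer aborts the stream before the message is saved, and the hard-coded error='Workflow stopped.' gives way to event.get_stop_reason() (added to QueueStopEvent further down). A minimal standalone sketch of the control-flow difference, with hypothetical event strings:

    from collections.abc import Iterator

    def stream(workflow_run: object | None, graph_runtime_state: object | None) -> Iterator[str]:
        # guard instead of raise: finishing the run is now conditional...
        if workflow_run and graph_runtime_state:
            yield 'workflow_finished'
        # ...while saving the message is always reached
        yield 'message_saved'

    print(list(stream(None, None)))  # ['message_saved']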
@@ -471,7 +466,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
         if self._conversation_name_generate_thread:
             self._conversation_name_generate_thread.join()
 
-    def _save_message(self, graph_runtime_state: GraphRuntimeState) -> None:
+    def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
         """
         Save message.
         :return:
@@ -483,7 +478,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
         self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
             if self._task_state.metadata else None
 
-        if graph_runtime_state.llm_usage:
+        if graph_runtime_state and graph_runtime_state.llm_usage:
             usage = graph_runtime_state.llm_usage
             self._message.message_tokens = usage.prompt_tokens
             self._message.message_unit_price = usage.prompt_unit_price
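Note: _save_message now accepts Optional[GraphRuntimeState] = None, and the usage block checks the argument before dereferencing it; together these let the stop path above save the message even when no node ever ran. A minimal sketch of the signature change, with stand-in types:

    from typing import Optional

    class GraphRuntimeState:
        llm_usage = None  # stand-in for the real LLM usage object

    def save_message(graph_runtime_state: Optional[GraphRuntimeState] = None) -> str:
        # the guard short-circuits before touching attributes on None
        if graph_runtime_state and graph_runtime_state.llm_usage:
            return 'saved with usage'
        return 'saved without usage'

    print(save_message())  # saved without usage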
@@ -511,7 +506,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
         """
         extras = {}
         if self._task_state.metadata:
-            extras['metadata'] = self._task_state.metadata
+            extras['metadata'] = self._task_state.metadata.copy()
+
+            if 'annotation_reply' in extras['metadata']:
+                del extras['metadata']['annotation_reply']
 
         return MessageEndStreamResponse(
             task_id=self._application_generate_entity.task_id,
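Note: copying the metadata before deleting 'annotation_reply' keeps the mutation local to the stream response; without .copy() the key would also vanish from self._task_state.metadata, which the saved message still serializes. A runnable illustration with hypothetical metadata:

    metadata = {'usage': {'tokens': 10}, 'annotation_reply': {'id': '1'}}

    extras = {'metadata': metadata.copy()}          # shallow copy of the top level
    if 'annotation_reply' in extras['metadata']:
        del extras['metadata']['annotation_reply']  # the original dict keeps the key

    print('annotation_reply' in metadata)  # True: task state is untouched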
@@ -50,6 +50,7 @@ from models.workflow import (
     WorkflowAppLog,
     WorkflowAppLogCreatedFrom,
     WorkflowRun,
+    WorkflowRunStatus,
 )
 
 logger = logging.getLogger(__name__)
@@ -304,7 +305,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
                         workflow_run=workflow_run,
                         event=event
                     )
-            elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
+            elif isinstance(event, QueueWorkflowSucceededEvent):
                 if not workflow_run:
                     raise Exception('Workflow run not initialized.')
 
@@ -324,6 +325,31 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
                 # save workflow app log
                 self._save_workflow_app_log(workflow_run)
 
+                yield self._workflow_finish_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run
+                )
+            elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                if not graph_runtime_state:
+                    raise Exception('Graph runtime state not initialized.')
+
+                workflow_run = self._handle_workflow_run_failed(
+                    workflow_run=workflow_run,
+                    start_at=graph_runtime_state.start_at,
+                    total_tokens=graph_runtime_state.total_tokens,
+                    total_steps=graph_runtime_state.node_run_steps,
+                    status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
+                    error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
+                    conversation_id=None,
+                    trace_manager=trace_manager,
+                )
+
+                # save workflow app log
+                self._save_workflow_app_log(workflow_run)
+
                 yield self._workflow_finish_to_stream_response(
                     task_id=self._application_generate_entity.task_id,
                     workflow_run=workflow_run
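Note: the previous combined success/failure/stop branch is split so that only QueueWorkflowFailedEvent and QueueStopEvent go through _handle_workflow_run_failed, with the terminal status and error chosen per event type. A self-contained sketch of that selection, using dummy stand-ins for the real event classes:

    from enum import Enum

    class WorkflowRunStatus(Enum):
        FAILED = 'failed'
        STOPPED = 'stopped'

    class QueueWorkflowFailedEvent:
        error = 'node failed'  # hypothetical error payload

    class QueueStopEvent:
        def get_stop_reason(self) -> str:
            return 'Stopped by user.'

    for event in (QueueWorkflowFailedEvent(), QueueStopEvent()):
        failed = isinstance(event, QueueWorkflowFailedEvent)
        status = WorkflowRunStatus.FAILED if failed else WorkflowRunStatus.STOPPED
        error = event.error if failed else event.get_stop_reason()
        print(status.value, error)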
@@ -1,4 +1,5 @@
-from typing import Any, Mapping, Optional, cast
+from collections.abc import Mapping
+from typing import Any, Optional, cast
 
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.base_app_runner import AppRunner
@@ -198,7 +199,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     parallel_id=event.parallel_id,
                     parallel_start_node_id=event.parallel_start_node_id,
                     start_at=event.route_node_state.start_at,
-                    node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
+                    node_run_index=event.route_node_state.index,
                     predecessor_node_id=event.predecessor_node_id
                 )
             )
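Note: node_run_index now comes from event.route_node_state.index, stamped when the node starts (see the index field added to RouteNodeState and the GraphEngine loop below), instead of being read from the shared node_run_steps counter at handling time, which may no longer match the start-time value once other nodes have run. A sequential sketch of the difference:

    from dataclasses import dataclass

    @dataclass
    class RouteNodeState:
        index: int = 1

    class GraphRuntimeState:
        node_run_steps = 0

    state = GraphRuntimeState()
    events = []
    for _ in range(2):  # two nodes start before either event is handled
        state.node_run_steps += 1
        events.append(RouteNodeState(index=state.node_run_steps))

    # old behaviour: every handler reads the counter late and sees 2
    print([e.index for e in events], state.node_run_steps)  # [1, 2] 2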
@@ -324,6 +324,19 @@ class QueueStopEvent(AppQueueEvent):
     event: QueueEvent = QueueEvent.STOP
     stopped_by: StopBy
 
+    def get_stop_reason(self) -> str:
+        """
+        To stop reason
+        """
+        reason_mapping = {
+            QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.',
+            QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.',
+            QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
+            QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
+        }
+
+        return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
+
 
 class QueueMessage(BaseModel):
     """
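Note: get_stop_reason maps the four known StopBy values to readable messages and falls back to a generic one via dict.get for anything else. A usage sketch with a stand-in enum:

    from enum import Enum

    class StopBy(Enum):
        USER_MANUAL = 'user-manual'
        OUTPUT_MODERATION = 'output-moderation'

    reason_mapping = {StopBy.USER_MANUAL: 'Stopped by user.'}

    print(reason_mapping.get(StopBy.USER_MANUAL, 'Stopped by unknown reason.'))       # mapped
    print(reason_mapping.get(StopBy.OUTPUT_MODERATION, 'Stopped by unknown reason.')) # fallback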
@@ -8,6 +8,7 @@ from core.app.entities.app_invoke_entities import (
     AgentChatAppGenerateEntity,
     ChatAppGenerateEntity,
     CompletionAppGenerateEntity,
+    InvokeFrom,
 )
 from core.app.entities.queue_entities import (
     QueueAnnotationReplyEvent,
@@ -178,6 +178,23 @@ class WorkflowCycleManage:
         workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
 
         db.session.commit()
+
+        running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
+            WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
+            WorkflowNodeExecution.app_id == workflow_run.app_id,
+            WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
+            WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+            WorkflowNodeExecution.workflow_run_id == workflow_run.id,
+            WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
+        ).all()
+
+        for workflow_node_execution in running_workflow_node_executions:
+            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
+            workflow_node_execution.error = error
+            workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
+            db.session.commit()
+
         db.session.refresh(workflow_run)
         db.session.close()
 
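Note: when a run is marked failed, any node executions still in RUNNING state are now finalized with the run's error, so no rows are left "running" forever. A plain-object sketch of the invariant (the real code does this through the SQLAlchemy session):

    from datetime import datetime, timezone

    class NodeExecution:
        def __init__(self) -> None:
            self.status = 'running'
            self.error: str | None = None
            self.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
            self.finished_at: datetime | None = None
            self.elapsed_time = 0.0

    def fail_running(executions: list[NodeExecution], error: str) -> None:
        for ex in (e for e in executions if e.status == 'running'):
            ex.status = 'failed'
            ex.error = error
            ex.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
            ex.elapsed_time = (ex.finished_at - ex.created_at).total_seconds()

    executions = [NodeExecution()]
    fail_running(executions, 'Workflow stopped.')
    print(executions[0].status, executions[0].error)  # failed Workflow stopped.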
@@ -43,6 +43,8 @@ class RouteNodeState(BaseModel):
     paused_by: Optional[str] = None
     """paused by"""
 
+    index: int = 1
+
     def set_finished(self, run_result: NodeRunResult) -> None:
         """
         Node finished
@@ -39,7 +39,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
 from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
-from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
+from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
 from core.workflow.nodes.node_mapping import node_classes
 from extensions.ext_database import db
 from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@@ -193,13 +193,20 @@ class GraphEngine:
 
             try:
                 # run node
-                yield from self._run_node(
+                generator = self._run_node(
                     node_instance=node_instance,
                     route_node_state=route_node_state,
                     parallel_id=in_parallel_id,
                     parallel_start_node_id=parallel_start_node_id
                 )
 
+                for item in generator:
+                    if isinstance(item, NodeRunStartedEvent):
+                        self.graph_runtime_state.node_run_steps += 1
+                        item.route_node_state.index = self.graph_runtime_state.node_run_steps
+
+                    yield item
+
                 self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state
 
                 # append route
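Note: replacing yield from with an explicit loop lets the engine intercept events in flight: node_run_steps is incremented exactly when a NodeRunStartedEvent passes through, and the step number is stamped onto the event's route node state. The unconditional node_run_steps += 1 removed in the following hunk moves here, tied to the start event. A standalone sketch of the interception pattern, with dict events standing in for the real event classes:

    from collections.abc import Iterator

    class Counter:
        steps = 0

    def run_node() -> Iterator[dict]:
        yield {'type': 'started'}
        yield {'type': 'chunk'}

    def run(counter: Counter) -> Iterator[dict]:
        for item in run_node():
            if item['type'] == 'started':  # count only when a node actually starts
                counter.steps += 1
                item['index'] = counter.steps
            yield item                     # forward every event

    counter = Counter()
    print(list(run(counter)), counter.steps)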
@@ -394,8 +401,6 @@ class GraphEngine:
 
         db.session.close()
 
-        self.graph_runtime_state.node_run_steps += 1
-
         try:
             # run node
             generator = node_instance.run()