mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-17 04:36:01 +08:00
feat(workflow): integrate parallel into workflow apps
This commit is contained in:
parent
1973f5003b
commit
352c45c8a2
@ -21,6 +21,9 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueNodeFailedEvent,
|
QueueNodeFailedEvent,
|
||||||
QueueNodeStartedEvent,
|
QueueNodeStartedEvent,
|
||||||
QueueNodeSucceededEvent,
|
QueueNodeSucceededEvent,
|
||||||
|
QueueParallelBranchRunFailedEvent,
|
||||||
|
QueueParallelBranchRunStartedEvent,
|
||||||
|
QueueParallelBranchRunSucceededEvent,
|
||||||
QueuePingEvent,
|
QueuePingEvent,
|
||||||
QueueRetrieverResourcesEvent,
|
QueueRetrieverResourcesEvent,
|
||||||
QueueStopEvent,
|
QueueStopEvent,
|
||||||
@ -304,6 +307,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
|
|
||||||
if response:
|
if response:
|
||||||
yield response
|
yield response
|
||||||
|
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||||
|
if not workflow_run:
|
||||||
|
raise Exception('Workflow run not initialized.')
|
||||||
|
|
||||||
|
yield self._workflow_parallel_branch_start_to_stream_response(
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
event=event
|
||||||
|
)
|
||||||
|
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
|
||||||
|
if not workflow_run:
|
||||||
|
raise Exception('Workflow run not initialized.')
|
||||||
|
|
||||||
|
yield self._workflow_parallel_branch_finished_to_stream_response(
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
event=event
|
||||||
|
)
|
||||||
elif isinstance(event, QueueIterationStartEvent):
|
elif isinstance(event, QueueIterationStartEvent):
|
||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
raise Exception('Workflow run not initialized.')
|
raise Exception('Workflow run not initialized.')
|
||||||
|
@ -13,6 +13,7 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueNodeSucceededEvent,
|
QueueNodeSucceededEvent,
|
||||||
QueueParallelBranchRunFailedEvent,
|
QueueParallelBranchRunFailedEvent,
|
||||||
QueueParallelBranchRunStartedEvent,
|
QueueParallelBranchRunStartedEvent,
|
||||||
|
QueueParallelBranchRunSucceededEvent,
|
||||||
QueueRetrieverResourcesEvent,
|
QueueRetrieverResourcesEvent,
|
||||||
QueueTextChunkEvent,
|
QueueTextChunkEvent,
|
||||||
QueueWorkflowFailedEvent,
|
QueueWorkflowFailedEvent,
|
||||||
@ -261,14 +262,16 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueParallelBranchRunStartedEvent(
|
QueueParallelBranchRunStartedEvent(
|
||||||
parallel_id=event.parallel_id,
|
parallel_id=event.parallel_id,
|
||||||
parallel_start_node_id=event.parallel_start_node_id
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
in_iteration_id=event.in_iteration_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueParallelBranchRunStartedEvent(
|
QueueParallelBranchRunSucceededEvent(
|
||||||
parallel_id=event.parallel_id,
|
parallel_id=event.parallel_id,
|
||||||
parallel_start_node_id=event.parallel_start_node_id
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
in_iteration_id=event.in_iteration_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||||
@ -276,6 +279,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||||||
QueueParallelBranchRunFailedEvent(
|
QueueParallelBranchRunFailedEvent(
|
||||||
parallel_id=event.parallel_id,
|
parallel_id=event.parallel_id,
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
in_iteration_id=event.in_iteration_id,
|
||||||
error=event.error
|
error=event.error
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -15,10 +15,10 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
NodeRunStartedEvent,
|
NodeRunStartedEvent,
|
||||||
NodeRunStreamChunkEvent,
|
NodeRunStreamChunkEvent,
|
||||||
NodeRunSucceededEvent,
|
NodeRunSucceededEvent,
|
||||||
|
ParallelBranchRunFailedEvent,
|
||||||
|
ParallelBranchRunStartedEvent,
|
||||||
|
ParallelBranchRunSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
|
||||||
|
|
||||||
_TEXT_COLOR_MAPPING = {
|
_TEXT_COLOR_MAPPING = {
|
||||||
"blue": "36;1",
|
"blue": "36;1",
|
||||||
@ -36,9 +36,6 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
|||||||
|
|
||||||
def on_event(
|
def on_event(
|
||||||
self,
|
self,
|
||||||
graph: Graph,
|
|
||||||
graph_init_params: GraphInitParams,
|
|
||||||
graph_runtime_state: GraphRuntimeState,
|
|
||||||
event: GraphEngineEvent
|
event: GraphEngineEvent
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(event, GraphRunStartedEvent):
|
if isinstance(event, GraphRunStartedEvent):
|
||||||
@ -49,49 +46,38 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
|||||||
self.print_text(f"\n[on_workflow_run_failed] reason: {event.error}", color='red')
|
self.print_text(f"\n[on_workflow_run_failed] reason: {event.error}", color='red')
|
||||||
elif isinstance(event, NodeRunStartedEvent):
|
elif isinstance(event, NodeRunStartedEvent):
|
||||||
self.on_workflow_node_execute_started(
|
self.on_workflow_node_execute_started(
|
||||||
graph=graph,
|
|
||||||
event=event
|
event=event
|
||||||
)
|
)
|
||||||
elif isinstance(event, NodeRunSucceededEvent):
|
elif isinstance(event, NodeRunSucceededEvent):
|
||||||
self.on_workflow_node_execute_succeeded(
|
self.on_workflow_node_execute_succeeded(
|
||||||
graph=graph,
|
|
||||||
graph_init_params=graph_init_params,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
event=event
|
event=event
|
||||||
)
|
)
|
||||||
elif isinstance(event, NodeRunFailedEvent):
|
elif isinstance(event, NodeRunFailedEvent):
|
||||||
self.on_workflow_node_execute_failed(
|
self.on_workflow_node_execute_failed(
|
||||||
graph=graph,
|
|
||||||
graph_init_params=graph_init_params,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
event=event
|
event=event
|
||||||
)
|
)
|
||||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||||
self.on_node_text_chunk(
|
self.on_node_text_chunk(
|
||||||
graph=graph,
|
event=event
|
||||||
graph_init_params=graph_init_params,
|
)
|
||||||
graph_runtime_state=graph_runtime_state,
|
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||||
|
self.on_workflow_parallel_started(
|
||||||
|
event=event
|
||||||
|
)
|
||||||
|
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
|
||||||
|
self.on_workflow_parallel_completed(
|
||||||
event=event
|
event=event
|
||||||
)
|
)
|
||||||
elif isinstance(event, IterationRunStartedEvent):
|
elif isinstance(event, IterationRunStartedEvent):
|
||||||
self.on_workflow_iteration_started(
|
self.on_workflow_iteration_started(
|
||||||
graph=graph,
|
|
||||||
graph_init_params=graph_init_params,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
event=event
|
event=event
|
||||||
)
|
)
|
||||||
elif isinstance(event, IterationRunNextEvent):
|
elif isinstance(event, IterationRunNextEvent):
|
||||||
self.on_workflow_iteration_next(
|
self.on_workflow_iteration_next(
|
||||||
graph=graph,
|
|
||||||
graph_init_params=graph_init_params,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
event=event
|
event=event
|
||||||
)
|
)
|
||||||
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
|
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
|
||||||
self.on_workflow_iteration_completed(
|
self.on_workflow_iteration_completed(
|
||||||
graph=graph,
|
|
||||||
graph_init_params=graph_init_params,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
event=event
|
event=event
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -99,39 +85,29 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
|||||||
|
|
||||||
def on_workflow_node_execute_started(
|
def on_workflow_node_execute_started(
|
||||||
self,
|
self,
|
||||||
graph: Graph,
|
|
||||||
event: NodeRunStartedEvent
|
event: NodeRunStartedEvent
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Workflow node execute started
|
Workflow node execute started
|
||||||
"""
|
"""
|
||||||
route_node_state = event.route_node_state
|
route_node_state = event.route_node_state
|
||||||
node_config = graph.node_id_config_mapping.get(route_node_state.node_id)
|
node_type = event.node_type.value
|
||||||
node_type = None
|
|
||||||
if node_config:
|
|
||||||
node_type = node_config.get("data", {}).get("type")
|
|
||||||
|
|
||||||
self.print_text("\n[on_workflow_node_execute_started]", color='yellow')
|
self.print_text("\n[NodeRunStartedEvent]", color='yellow')
|
||||||
self.print_text(f"Node ID: {route_node_state.node_id}", color='yellow')
|
self.print_text(f"Node ID: {route_node_state.node_id}", color='yellow')
|
||||||
self.print_text(f"Type: {node_type}", color='yellow')
|
self.print_text(f"Type: {node_type}", color='yellow')
|
||||||
|
|
||||||
def on_workflow_node_execute_succeeded(
|
def on_workflow_node_execute_succeeded(
|
||||||
self,
|
self,
|
||||||
graph: Graph,
|
|
||||||
graph_init_params: GraphInitParams,
|
|
||||||
graph_runtime_state: GraphRuntimeState,
|
|
||||||
event: NodeRunSucceededEvent
|
event: NodeRunSucceededEvent
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Workflow node execute succeeded
|
Workflow node execute succeeded
|
||||||
"""
|
"""
|
||||||
route_node_state = event.route_node_state
|
route_node_state = event.route_node_state
|
||||||
node_config = graph.node_id_config_mapping.get(route_node_state.node_id)
|
node_type = event.node_type.value
|
||||||
node_type = None
|
|
||||||
if node_config:
|
|
||||||
node_type = node_config.get("data", {}).get("type")
|
|
||||||
|
|
||||||
self.print_text("\n[on_workflow_node_execute_succeeded]", color='green')
|
self.print_text("\n[NodeRunSucceededEvent]", color='green')
|
||||||
self.print_text(f"Node ID: {route_node_state.node_id}", color='green')
|
self.print_text(f"Node ID: {route_node_state.node_id}", color='green')
|
||||||
self.print_text(f"Type: {node_type}", color='green')
|
self.print_text(f"Type: {node_type}", color='green')
|
||||||
|
|
||||||
@ -150,21 +126,15 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
|||||||
|
|
||||||
def on_workflow_node_execute_failed(
|
def on_workflow_node_execute_failed(
|
||||||
self,
|
self,
|
||||||
graph: Graph,
|
|
||||||
graph_init_params: GraphInitParams,
|
|
||||||
graph_runtime_state: GraphRuntimeState,
|
|
||||||
event: NodeRunFailedEvent
|
event: NodeRunFailedEvent
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Workflow node execute failed
|
Workflow node execute failed
|
||||||
"""
|
"""
|
||||||
route_node_state = event.route_node_state
|
route_node_state = event.route_node_state
|
||||||
node_config = graph.node_id_config_mapping.get(route_node_state.node_id)
|
node_type = event.node_type.value
|
||||||
node_type = None
|
|
||||||
if node_config:
|
|
||||||
node_type = node_config.get("data", {}).get("type")
|
|
||||||
|
|
||||||
self.print_text("\n[on_workflow_node_execute_failed]", color='red')
|
self.print_text("\n[NodeRunFailedEvent]", color='red')
|
||||||
self.print_text(f"Node ID: {route_node_state.node_id}", color='red')
|
self.print_text(f"Node ID: {route_node_state.node_id}", color='red')
|
||||||
self.print_text(f"Type: {node_type}", color='red')
|
self.print_text(f"Type: {node_type}", color='red')
|
||||||
|
|
||||||
@ -181,9 +151,6 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
|||||||
|
|
||||||
def on_node_text_chunk(
|
def on_node_text_chunk(
|
||||||
self,
|
self,
|
||||||
graph: Graph,
|
|
||||||
graph_init_params: GraphInitParams,
|
|
||||||
graph_runtime_state: GraphRuntimeState,
|
|
||||||
event: NodeRunStreamChunkEvent
|
event: NodeRunStreamChunkEvent
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -192,7 +159,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
|||||||
route_node_state = event.route_node_state
|
route_node_state = event.route_node_state
|
||||||
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
|
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
|
||||||
self.current_node_id = route_node_state.node_id
|
self.current_node_id = route_node_state.node_id
|
||||||
self.print_text('\n[on_node_text_chunk]')
|
self.print_text('\n[NodeRunStreamChunkEvent]')
|
||||||
self.print_text(f"Node ID: {route_node_state.node_id}")
|
self.print_text(f"Node ID: {route_node_state.node_id}")
|
||||||
|
|
||||||
node_run_result = route_node_state.node_run_result
|
node_run_result = route_node_state.node_run_result
|
||||||
@ -202,43 +169,69 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
|||||||
|
|
||||||
self.print_text(event.chunk_content, color="pink", end="")
|
self.print_text(event.chunk_content, color="pink", end="")
|
||||||
|
|
||||||
|
def on_workflow_parallel_started(
|
||||||
|
self,
|
||||||
|
event: ParallelBranchRunStartedEvent
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Publish parallel started
|
||||||
|
"""
|
||||||
|
self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
|
||||||
|
self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
|
||||||
|
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
|
||||||
|
if event.in_iteration_id:
|
||||||
|
self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
|
||||||
|
|
||||||
|
def on_workflow_parallel_completed(
|
||||||
|
self,
|
||||||
|
event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Publish parallel completed
|
||||||
|
"""
|
||||||
|
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||||
|
color = 'blue'
|
||||||
|
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||||
|
color = 'red'
|
||||||
|
|
||||||
|
self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
|
||||||
|
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
|
||||||
|
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
|
||||||
|
if event.in_iteration_id:
|
||||||
|
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
|
||||||
|
|
||||||
|
if isinstance(event, ParallelBranchRunFailedEvent):
|
||||||
|
self.print_text(f"Error: {event.error}", color=color)
|
||||||
|
|
||||||
def on_workflow_iteration_started(
|
def on_workflow_iteration_started(
|
||||||
self,
|
self,
|
||||||
graph: Graph,
|
|
||||||
graph_init_params: GraphInitParams,
|
|
||||||
graph_runtime_state: GraphRuntimeState,
|
|
||||||
event: IterationRunStartedEvent
|
event: IterationRunStartedEvent
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Publish iteration started
|
Publish iteration started
|
||||||
"""
|
"""
|
||||||
self.print_text("\n[on_workflow_iteration_started]", color='blue')
|
self.print_text("\n[IterationRunStartedEvent]", color='blue')
|
||||||
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
|
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
|
||||||
|
|
||||||
def on_workflow_iteration_next(
|
def on_workflow_iteration_next(
|
||||||
self,
|
self,
|
||||||
graph: Graph,
|
|
||||||
graph_init_params: GraphInitParams,
|
|
||||||
graph_runtime_state: GraphRuntimeState,
|
|
||||||
event: IterationRunNextEvent
|
event: IterationRunNextEvent
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Publish iteration next
|
Publish iteration next
|
||||||
"""
|
"""
|
||||||
self.print_text("\n[on_workflow_iteration_next]", color='blue')
|
self.print_text("\n[IterationRunNextEvent]", color='blue')
|
||||||
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
|
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
|
||||||
|
self.print_text(f"Iteration Index: {event.index}", color='blue')
|
||||||
|
|
||||||
def on_workflow_iteration_completed(
|
def on_workflow_iteration_completed(
|
||||||
self,
|
self,
|
||||||
graph: Graph,
|
|
||||||
graph_init_params: GraphInitParams,
|
|
||||||
graph_runtime_state: GraphRuntimeState,
|
|
||||||
event: IterationRunSucceededEvent | IterationRunFailedEvent
|
event: IterationRunSucceededEvent | IterationRunFailedEvent
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Publish iteration completed
|
Publish iteration completed
|
||||||
"""
|
"""
|
||||||
self.print_text("\n[on_workflow_iteration_completed]", color='blue')
|
self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
|
||||||
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
|
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
|
||||||
|
|
||||||
def print_text(
|
def print_text(
|
||||||
|
@ -334,7 +334,7 @@ class QueueStopEvent(AppQueueEvent):
|
|||||||
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
|
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
|
||||||
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
|
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
|
||||||
}
|
}
|
||||||
|
|
||||||
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
|
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
|
||||||
|
|
||||||
|
|
||||||
@ -370,6 +370,8 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent):
|
|||||||
|
|
||||||
parallel_id: str
|
parallel_id: str
|
||||||
parallel_start_node_id: str
|
parallel_start_node_id: str
|
||||||
|
in_iteration_id: Optional[str] = None
|
||||||
|
"""iteration id if node is in iteration"""
|
||||||
|
|
||||||
|
|
||||||
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
||||||
@ -380,6 +382,8 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
|||||||
|
|
||||||
parallel_id: str
|
parallel_id: str
|
||||||
parallel_start_node_id: str
|
parallel_start_node_id: str
|
||||||
|
in_iteration_id: Optional[str] = None
|
||||||
|
"""iteration id if node is in iteration"""
|
||||||
|
|
||||||
|
|
||||||
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
||||||
@ -390,4 +394,6 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
|||||||
|
|
||||||
parallel_id: str
|
parallel_id: str
|
||||||
parallel_start_node_id: str
|
parallel_start_node_id: str
|
||||||
|
in_iteration_id: Optional[str] = None
|
||||||
|
"""iteration id if node is in iteration"""
|
||||||
error: str
|
error: str
|
||||||
|
@ -47,6 +47,8 @@ class StreamEvent(Enum):
|
|||||||
WORKFLOW_FINISHED = "workflow_finished"
|
WORKFLOW_FINISHED = "workflow_finished"
|
||||||
NODE_STARTED = "node_started"
|
NODE_STARTED = "node_started"
|
||||||
NODE_FINISHED = "node_finished"
|
NODE_FINISHED = "node_finished"
|
||||||
|
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
|
||||||
|
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
|
||||||
ITERATION_STARTED = "iteration_started"
|
ITERATION_STARTED = "iteration_started"
|
||||||
ITERATION_NEXT = "iteration_next"
|
ITERATION_NEXT = "iteration_next"
|
||||||
ITERATION_COMPLETED = "iteration_completed"
|
ITERATION_COMPLETED = "iteration_completed"
|
||||||
@ -295,6 +297,46 @@ class NodeFinishStreamResponse(StreamResponse):
|
|||||||
"files": []
|
"files": []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelBranchStartStreamResponse(StreamResponse):
|
||||||
|
"""
|
||||||
|
ParallelBranchStartStreamResponse entity
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Data(BaseModel):
|
||||||
|
"""
|
||||||
|
Data entity
|
||||||
|
"""
|
||||||
|
parallel_id: str
|
||||||
|
parallel_branch_id: str
|
||||||
|
iteration_id: Optional[str] = None
|
||||||
|
created_at: int
|
||||||
|
|
||||||
|
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
|
||||||
|
workflow_run_id: str
|
||||||
|
data: Data
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelBranchFinishedStreamResponse(StreamResponse):
|
||||||
|
"""
|
||||||
|
ParallelBranchFinishedStreamResponse entity
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Data(BaseModel):
|
||||||
|
"""
|
||||||
|
Data entity
|
||||||
|
"""
|
||||||
|
parallel_id: str
|
||||||
|
parallel_branch_id: str
|
||||||
|
iteration_id: Optional[str] = None
|
||||||
|
status: str
|
||||||
|
error: Optional[str] = None
|
||||||
|
created_at: int
|
||||||
|
|
||||||
|
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
|
||||||
|
workflow_run_id: str
|
||||||
|
data: Data
|
||||||
|
|
||||||
|
|
||||||
class IterationNodeStartStreamResponse(StreamResponse):
|
class IterationNodeStartStreamResponse(StreamResponse):
|
||||||
|
@ -11,6 +11,9 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueNodeFailedEvent,
|
QueueNodeFailedEvent,
|
||||||
QueueNodeStartedEvent,
|
QueueNodeStartedEvent,
|
||||||
QueueNodeSucceededEvent,
|
QueueNodeSucceededEvent,
|
||||||
|
QueueParallelBranchRunFailedEvent,
|
||||||
|
QueueParallelBranchRunStartedEvent,
|
||||||
|
QueueParallelBranchRunSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.app.entities.task_entities import (
|
from core.app.entities.task_entities import (
|
||||||
IterationNodeCompletedStreamResponse,
|
IterationNodeCompletedStreamResponse,
|
||||||
@ -18,6 +21,8 @@ from core.app.entities.task_entities import (
|
|||||||
IterationNodeStartStreamResponse,
|
IterationNodeStartStreamResponse,
|
||||||
NodeFinishStreamResponse,
|
NodeFinishStreamResponse,
|
||||||
NodeStartStreamResponse,
|
NodeStartStreamResponse,
|
||||||
|
ParallelBranchFinishedStreamResponse,
|
||||||
|
ParallelBranchStartStreamResponse,
|
||||||
WorkflowFinishStreamResponse,
|
WorkflowFinishStreamResponse,
|
||||||
WorkflowStartStreamResponse,
|
WorkflowStartStreamResponse,
|
||||||
WorkflowTaskState,
|
WorkflowTaskState,
|
||||||
@ -433,6 +438,56 @@ class WorkflowCycleManage:
|
|||||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _workflow_parallel_branch_start_to_stream_response(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
workflow_run: WorkflowRun,
|
||||||
|
event: QueueParallelBranchRunStartedEvent
|
||||||
|
) -> ParallelBranchStartStreamResponse:
|
||||||
|
"""
|
||||||
|
Workflow parallel branch start to stream response
|
||||||
|
:param task_id: task id
|
||||||
|
:param workflow_run: workflow run
|
||||||
|
:param event: parallel branch run started event
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return ParallelBranchStartStreamResponse(
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
data=ParallelBranchStartStreamResponse.Data(
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_branch_id=event.parallel_start_node_id,
|
||||||
|
iteration_id=event.in_iteration_id,
|
||||||
|
created_at=int(time.time()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _workflow_parallel_branch_finished_to_stream_response(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
workflow_run: WorkflowRun,
|
||||||
|
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
|
||||||
|
) -> ParallelBranchFinishedStreamResponse:
|
||||||
|
"""
|
||||||
|
Workflow parallel branch finished to stream response
|
||||||
|
:param task_id: task id
|
||||||
|
:param workflow_run: workflow run
|
||||||
|
:param event: parallel branch run succeeded or failed event
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return ParallelBranchFinishedStreamResponse(
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
data=ParallelBranchFinishedStreamResponse.Data(
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_branch_id=event.parallel_start_node_id,
|
||||||
|
iteration_id=event.in_iteration_id,
|
||||||
|
status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
|
||||||
|
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
|
||||||
|
created_at=int(time.time()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _workflow_iteration_start_to_stream_response(
|
def _workflow_iteration_start_to_stream_response(
|
||||||
self,
|
self,
|
||||||
|
@ -1,18 +1,12 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent
|
from core.workflow.graph_engine.entities.event import GraphEngineEvent
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowCallback(ABC):
|
class WorkflowCallback(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_event(
|
def on_event(
|
||||||
self,
|
self,
|
||||||
graph: Graph,
|
|
||||||
graph_init_params: GraphInitParams,
|
|
||||||
graph_runtime_state: GraphRuntimeState,
|
|
||||||
event: GraphEngineEvent
|
event: GraphEngineEvent
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -56,6 +56,8 @@ class NodeRunMetadataKey(Enum):
|
|||||||
TOOL_INFO = 'tool_info'
|
TOOL_INFO = 'tool_info'
|
||||||
ITERATION_ID = 'iteration_id'
|
ITERATION_ID = 'iteration_id'
|
||||||
ITERATION_INDEX = 'iteration_index'
|
ITERATION_INDEX = 'iteration_index'
|
||||||
|
PARALLEL_ID = 'parallel_id'
|
||||||
|
PARALLEL_START_NODE_ID = 'parallel_start_node_id'
|
||||||
|
|
||||||
|
|
||||||
class NodeRunResult(BaseModel):
|
class NodeRunResult(BaseModel):
|
||||||
|
@ -99,6 +99,10 @@ class Graph(BaseModel):
|
|||||||
if target_node_id not in reverse_edge_mapping:
|
if target_node_id not in reverse_edge_mapping:
|
||||||
reverse_edge_mapping[target_node_id] = []
|
reverse_edge_mapping[target_node_id] = []
|
||||||
|
|
||||||
|
# is target node id in source node id edge mapping
|
||||||
|
if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]):
|
||||||
|
continue
|
||||||
|
|
||||||
target_edge_ids.add(target_node_id)
|
target_edge_ids.add(target_node_id)
|
||||||
|
|
||||||
# parse run condition
|
# parse run condition
|
||||||
|
@ -244,6 +244,7 @@ class GraphEngine:
|
|||||||
|
|
||||||
if len(edge_mappings) == 1:
|
if len(edge_mappings) == 1:
|
||||||
edge = edge_mappings[0]
|
edge = edge_mappings[0]
|
||||||
|
|
||||||
if edge.run_condition:
|
if edge.run_condition:
|
||||||
result = ConditionManager.get_condition_handler(
|
result = ConditionManager.get_condition_handler(
|
||||||
init_params=self.init_params,
|
init_params=self.init_params,
|
||||||
@ -296,14 +297,20 @@ class GraphEngine:
|
|||||||
# run parallel nodes, run in new thread and use queue to get results
|
# run parallel nodes, run in new thread and use queue to get results
|
||||||
q: queue.Queue = queue.Queue()
|
q: queue.Queue = queue.Queue()
|
||||||
|
|
||||||
|
# Create a list to store the threads
|
||||||
|
threads = []
|
||||||
|
|
||||||
# new thread
|
# new thread
|
||||||
for edge in edge_mappings:
|
for edge in edge_mappings:
|
||||||
threading.Thread(target=self._run_parallel_node, kwargs={
|
thread = threading.Thread(target=self._run_parallel_node, kwargs={
|
||||||
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
|
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
|
||||||
'parallel_id': parallel_id,
|
'parallel_id': parallel_id,
|
||||||
'parallel_start_node_id': edge.target_node_id,
|
'parallel_start_node_id': edge.target_node_id,
|
||||||
'q': q
|
'q': q
|
||||||
}).start()
|
})
|
||||||
|
|
||||||
|
threads.append(thread)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
succeeded_count = 0
|
succeeded_count = 0
|
||||||
while True:
|
while True:
|
||||||
@ -315,8 +322,8 @@ class GraphEngine:
|
|||||||
yield event
|
yield event
|
||||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||||
succeeded_count += 1
|
succeeded_count += 1
|
||||||
if succeeded_count == len(edge_mappings):
|
if succeeded_count == len(threads):
|
||||||
break
|
q.put(None)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||||
@ -324,6 +331,10 @@ class GraphEngine:
|
|||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Join all threads
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
# get final node id
|
# get final node id
|
||||||
final_node_id = parallel.end_to_node_id
|
final_node_id = parallel.end_to_node_id
|
||||||
if not final_node_id:
|
if not final_node_id:
|
||||||
@ -331,8 +342,8 @@ class GraphEngine:
|
|||||||
|
|
||||||
next_node_id = final_node_id
|
next_node_id = final_node_id
|
||||||
|
|
||||||
# if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
|
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
|
||||||
# break
|
break
|
||||||
|
|
||||||
def _run_parallel_node(self,
|
def _run_parallel_node(self,
|
||||||
flask_app: Flask,
|
flask_app: Flask,
|
||||||
@ -449,6 +460,14 @@ class GraphEngine:
|
|||||||
variable_value=variable_value
|
variable_value=variable_value
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# add parallel info to run result metadata
|
||||||
|
if parallel_id and parallel_start_node_id:
|
||||||
|
if not run_result.metadata:
|
||||||
|
run_result.metadata = {}
|
||||||
|
|
||||||
|
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||||
|
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
|
||||||
|
|
||||||
yield NodeRunSucceededEvent(
|
yield NodeRunSucceededEvent(
|
||||||
id=node_instance.id,
|
id=node_instance.id,
|
||||||
node_id=node_instance.node_id,
|
node_id=node_instance.node_id,
|
||||||
|
@ -37,6 +37,10 @@ class AnswerStreamProcessor(StreamProcessor):
|
|||||||
|
|
||||||
yield event
|
yield event
|
||||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||||
|
if event.in_iteration_id:
|
||||||
|
yield event
|
||||||
|
continue
|
||||||
|
|
||||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||||
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
|
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
|
||||||
event.route_node_state.node_id
|
event.route_node_state.node_id
|
||||||
|
@ -157,14 +157,15 @@ class IterationNode(BaseNode):
|
|||||||
event.in_iteration_id = self.node_id
|
event.in_iteration_id = self.node_id
|
||||||
|
|
||||||
if isinstance(event, NodeRunSucceededEvent):
|
if isinstance(event, NodeRunSucceededEvent):
|
||||||
metadata = event.route_node_state.node_run_result.metadata
|
if event.route_node_state.node_run_result:
|
||||||
if not metadata:
|
metadata = event.route_node_state.node_run_result.metadata
|
||||||
metadata = {}
|
if not metadata:
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||||
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
|
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
|
||||||
event.route_node_state.node_run_result.metadata = metadata
|
event.route_node_state.node_run_result.metadata = metadata
|
||||||
|
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
|
@ -98,9 +98,6 @@ class WorkflowEntry:
|
|||||||
if callbacks:
|
if callbacks:
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
callback.on_event(
|
callback.on_event(
|
||||||
graph=self.graph_engine.graph,
|
|
||||||
graph_init_params=graph_engine.init_params,
|
|
||||||
graph_runtime_state=graph_engine.graph_runtime_state,
|
|
||||||
event=event
|
event=event
|
||||||
)
|
)
|
||||||
yield event
|
yield event
|
||||||
@ -111,9 +108,6 @@ class WorkflowEntry:
|
|||||||
if callbacks:
|
if callbacks:
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
callback.on_event(
|
callback.on_event(
|
||||||
graph=self.graph_engine.graph,
|
|
||||||
graph_init_params=graph_engine.init_params,
|
|
||||||
graph_runtime_state=graph_engine.graph_runtime_state,
|
|
||||||
event=GraphRunFailedEvent(
|
event=GraphRunFailedEvent(
|
||||||
error=str(e)
|
error=str(e)
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user