feat(workflow): integrate parallel into workflow apps

This commit is contained in:
takatost 2024-08-16 21:33:09 +08:00
parent 1973f5003b
commit 352c45c8a2
13 changed files with 233 additions and 94 deletions

View File

@ -21,6 +21,9 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent, QueueNodeFailedEvent,
QueueNodeStartedEvent, QueueNodeStartedEvent,
QueueNodeSucceededEvent, QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent, QueuePingEvent,
QueueRetrieverResourcesEvent, QueueRetrieverResourcesEvent,
QueueStopEvent, QueueStopEvent,
@ -304,6 +307,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if response: if response:
yield response yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationStartEvent): elif isinstance(event, QueueIterationStartEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception('Workflow run not initialized.')

View File

@ -13,6 +13,7 @@ from core.app.entities.queue_entities import (
QueueNodeSucceededEvent, QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent, QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent, QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueueRetrieverResourcesEvent, QueueRetrieverResourcesEvent,
QueueTextChunkEvent, QueueTextChunkEvent,
QueueWorkflowFailedEvent, QueueWorkflowFailedEvent,
@ -261,14 +262,16 @@ class WorkflowBasedAppRunner(AppRunner):
self._publish_event( self._publish_event(
QueueParallelBranchRunStartedEvent( QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id parallel_start_node_id=event.parallel_start_node_id,
in_iteration_id=event.in_iteration_id
) )
) )
elif isinstance(event, ParallelBranchRunSucceededEvent): elif isinstance(event, ParallelBranchRunSucceededEvent):
self._publish_event( self._publish_event(
QueueParallelBranchRunStartedEvent( QueueParallelBranchRunSucceededEvent(
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id parallel_start_node_id=event.parallel_start_node_id,
in_iteration_id=event.in_iteration_id
) )
) )
elif isinstance(event, ParallelBranchRunFailedEvent): elif isinstance(event, ParallelBranchRunFailedEvent):
@ -276,6 +279,7 @@ class WorkflowBasedAppRunner(AppRunner):
QueueParallelBranchRunFailedEvent( QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
error=event.error error=event.error
) )
) )

View File

@ -15,10 +15,10 @@ from core.workflow.graph_engine.entities.event import (
NodeRunStartedEvent, NodeRunStartedEvent,
NodeRunStreamChunkEvent, NodeRunStreamChunkEvent,
NodeRunSucceededEvent, NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
_TEXT_COLOR_MAPPING = { _TEXT_COLOR_MAPPING = {
"blue": "36;1", "blue": "36;1",
@ -36,9 +36,6 @@ class WorkflowLoggingCallback(WorkflowCallback):
def on_event( def on_event(
self, self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: GraphEngineEvent event: GraphEngineEvent
) -> None: ) -> None:
if isinstance(event, GraphRunStartedEvent): if isinstance(event, GraphRunStartedEvent):
@ -49,49 +46,38 @@ class WorkflowLoggingCallback(WorkflowCallback):
self.print_text(f"\n[on_workflow_run_failed] reason: {event.error}", color='red') self.print_text(f"\n[on_workflow_run_failed] reason: {event.error}", color='red')
elif isinstance(event, NodeRunStartedEvent): elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started( self.on_workflow_node_execute_started(
graph=graph,
event=event event=event
) )
elif isinstance(event, NodeRunSucceededEvent): elif isinstance(event, NodeRunSucceededEvent):
self.on_workflow_node_execute_succeeded( self.on_workflow_node_execute_succeeded(
graph=graph,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
event=event event=event
) )
elif isinstance(event, NodeRunFailedEvent): elif isinstance(event, NodeRunFailedEvent):
self.on_workflow_node_execute_failed( self.on_workflow_node_execute_failed(
graph=graph,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
event=event event=event
) )
elif isinstance(event, NodeRunStreamChunkEvent): elif isinstance(event, NodeRunStreamChunkEvent):
self.on_node_text_chunk( self.on_node_text_chunk(
graph=graph, event=event
graph_init_params=graph_init_params, )
graph_runtime_state=graph_runtime_state, elif isinstance(event, ParallelBranchRunStartedEvent):
self.on_workflow_parallel_started(
event=event
)
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
self.on_workflow_parallel_completed(
event=event event=event
) )
elif isinstance(event, IterationRunStartedEvent): elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started( self.on_workflow_iteration_started(
graph=graph,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
event=event event=event
) )
elif isinstance(event, IterationRunNextEvent): elif isinstance(event, IterationRunNextEvent):
self.on_workflow_iteration_next( self.on_workflow_iteration_next(
graph=graph,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
event=event event=event
) )
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent): elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
self.on_workflow_iteration_completed( self.on_workflow_iteration_completed(
graph=graph,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
event=event event=event
) )
else: else:
@ -99,39 +85,29 @@ class WorkflowLoggingCallback(WorkflowCallback):
def on_workflow_node_execute_started( def on_workflow_node_execute_started(
self, self,
graph: Graph,
event: NodeRunStartedEvent event: NodeRunStartedEvent
) -> None: ) -> None:
""" """
Workflow node execute started Workflow node execute started
""" """
route_node_state = event.route_node_state route_node_state = event.route_node_state
node_config = graph.node_id_config_mapping.get(route_node_state.node_id) node_type = event.node_type.value
node_type = None
if node_config:
node_type = node_config.get("data", {}).get("type")
self.print_text("\n[on_workflow_node_execute_started]", color='yellow') self.print_text("\n[NodeRunStartedEvent]", color='yellow')
self.print_text(f"Node ID: {route_node_state.node_id}", color='yellow') self.print_text(f"Node ID: {route_node_state.node_id}", color='yellow')
self.print_text(f"Type: {node_type}", color='yellow') self.print_text(f"Type: {node_type}", color='yellow')
def on_workflow_node_execute_succeeded( def on_workflow_node_execute_succeeded(
self, self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: NodeRunSucceededEvent event: NodeRunSucceededEvent
) -> None: ) -> None:
""" """
Workflow node execute succeeded Workflow node execute succeeded
""" """
route_node_state = event.route_node_state route_node_state = event.route_node_state
node_config = graph.node_id_config_mapping.get(route_node_state.node_id) node_type = event.node_type.value
node_type = None
if node_config:
node_type = node_config.get("data", {}).get("type")
self.print_text("\n[on_workflow_node_execute_succeeded]", color='green') self.print_text("\n[NodeRunSucceededEvent]", color='green')
self.print_text(f"Node ID: {route_node_state.node_id}", color='green') self.print_text(f"Node ID: {route_node_state.node_id}", color='green')
self.print_text(f"Type: {node_type}", color='green') self.print_text(f"Type: {node_type}", color='green')
@ -150,21 +126,15 @@ class WorkflowLoggingCallback(WorkflowCallback):
def on_workflow_node_execute_failed( def on_workflow_node_execute_failed(
self, self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: NodeRunFailedEvent event: NodeRunFailedEvent
) -> None: ) -> None:
""" """
Workflow node execute failed Workflow node execute failed
""" """
route_node_state = event.route_node_state route_node_state = event.route_node_state
node_config = graph.node_id_config_mapping.get(route_node_state.node_id) node_type = event.node_type.value
node_type = None
if node_config:
node_type = node_config.get("data", {}).get("type")
self.print_text("\n[on_workflow_node_execute_failed]", color='red') self.print_text("\n[NodeRunFailedEvent]", color='red')
self.print_text(f"Node ID: {route_node_state.node_id}", color='red') self.print_text(f"Node ID: {route_node_state.node_id}", color='red')
self.print_text(f"Type: {node_type}", color='red') self.print_text(f"Type: {node_type}", color='red')
@ -181,9 +151,6 @@ class WorkflowLoggingCallback(WorkflowCallback):
def on_node_text_chunk( def on_node_text_chunk(
self, self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: NodeRunStreamChunkEvent event: NodeRunStreamChunkEvent
) -> None: ) -> None:
""" """
@ -192,7 +159,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
route_node_state = event.route_node_state route_node_state = event.route_node_state
if not self.current_node_id or self.current_node_id != route_node_state.node_id: if not self.current_node_id or self.current_node_id != route_node_state.node_id:
self.current_node_id = route_node_state.node_id self.current_node_id = route_node_state.node_id
self.print_text('\n[on_node_text_chunk]') self.print_text('\n[NodeRunStreamChunkEvent]')
self.print_text(f"Node ID: {route_node_state.node_id}") self.print_text(f"Node ID: {route_node_state.node_id}")
node_run_result = route_node_state.node_run_result node_run_result = route_node_state.node_run_result
@ -202,43 +169,69 @@ class WorkflowLoggingCallback(WorkflowCallback):
self.print_text(event.chunk_content, color="pink", end="") self.print_text(event.chunk_content, color="pink", end="")
def on_workflow_parallel_started(
        self,
        event: ParallelBranchRunStartedEvent
) -> None:
    """
    Log that a parallel branch has started running.

    Prints the parallel id and the branch's start node id in blue; when
    the branch runs inside an iteration node, the iteration id is
    printed as well.
    """
    lines = [
        "\n[ParallelBranchRunStartedEvent]",
        f"Parallel ID: {event.parallel_id}",
        f"Branch ID: {event.parallel_start_node_id}",
    ]
    if event.in_iteration_id:
        lines.append(f"Iteration ID: {event.in_iteration_id}")
    for line in lines:
        self.print_text(line, color='blue')
def on_workflow_parallel_completed(
        self,
        event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
) -> None:
    """
    Log that a parallel branch finished running.

    Success is printed in blue, failure in red. A failed branch
    additionally prints its error message; the iteration id is included
    when the branch ran inside an iteration node.
    """
    failed = isinstance(event, ParallelBranchRunFailedEvent)
    # Derive the colour from a single boolean so `color` is always bound —
    # the original if/elif left it undefined for any future event subtype.
    color = 'red' if failed else 'blue'

    self.print_text(
        "\n[ParallelBranchRunFailedEvent]" if failed else "\n[ParallelBranchRunSucceededEvent]",
        color=color
    )
    self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
    self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
    if event.in_iteration_id:
        self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
    if failed:
        self.print_text(f"Error: {event.error}", color=color)
def on_workflow_iteration_started( def on_workflow_iteration_started(
self, self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: IterationRunStartedEvent event: IterationRunStartedEvent
) -> None: ) -> None:
""" """
Publish iteration started Publish iteration started
""" """
self.print_text("\n[on_workflow_iteration_started]", color='blue') self.print_text("\n[IterationRunStartedEvent]", color='blue')
self.print_text(f"Node ID: {event.iteration_id}", color='blue') self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
def on_workflow_iteration_next( def on_workflow_iteration_next(
self, self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: IterationRunNextEvent event: IterationRunNextEvent
) -> None: ) -> None:
""" """
Publish iteration next Publish iteration next
""" """
self.print_text("\n[on_workflow_iteration_next]", color='blue') self.print_text("\n[IterationRunNextEvent]", color='blue')
self.print_text(f"Node ID: {event.iteration_id}", color='blue') self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
self.print_text(f"Iteration Index: {event.index}", color='blue')
def on_workflow_iteration_completed( def on_workflow_iteration_completed(
self, self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: IterationRunSucceededEvent | IterationRunFailedEvent event: IterationRunSucceededEvent | IterationRunFailedEvent
) -> None: ) -> None:
""" """
Publish iteration completed Publish iteration completed
""" """
self.print_text("\n[on_workflow_iteration_completed]", color='blue') self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
self.print_text(f"Node ID: {event.iteration_id}", color='blue') self.print_text(f"Node ID: {event.iteration_id}", color='blue')
def print_text( def print_text(

View File

@ -334,7 +334,7 @@ class QueueStopEvent(AppQueueEvent):
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.', QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.' QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
} }
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.') return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
@ -370,6 +370,8 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent):
parallel_id: str parallel_id: str
parallel_start_node_id: str parallel_start_node_id: str
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueParallelBranchRunSucceededEvent(AppQueueEvent): class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
@ -380,6 +382,8 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
parallel_id: str parallel_id: str
parallel_start_node_id: str parallel_start_node_id: str
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueParallelBranchRunFailedEvent(AppQueueEvent): class QueueParallelBranchRunFailedEvent(AppQueueEvent):
@ -390,4 +394,6 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent):
parallel_id: str parallel_id: str
parallel_start_node_id: str parallel_start_node_id: str
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
error: str error: str

View File

@ -47,6 +47,8 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished" WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started" NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished" NODE_FINISHED = "node_finished"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started" ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next" ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed" ITERATION_COMPLETED = "iteration_completed"
@ -295,6 +297,46 @@ class NodeFinishStreamResponse(StreamResponse):
"files": [] "files": []
} }
} }
class ParallelBranchStartStreamResponse(StreamResponse):
    """
    Stream event payload emitted to the client when a parallel branch
    begins executing within a workflow run.
    """
    class Data(BaseModel):
        """
        Details of the parallel branch that started.
        """
        parallel_id: str  # id of the parallel construct in the workflow graph
        parallel_branch_id: str  # start node id of this branch
        iteration_id: Optional[str] = None  # iteration id when the branch runs inside an iteration, else None
        created_at: int  # unix timestamp (seconds) when the response was built

    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED  # stream event discriminator
    workflow_run_id: str
    data: Data
class ParallelBranchFinishedStreamResponse(StreamResponse):
    """
    Stream event payload emitted to the client when a parallel branch
    finishes executing (successfully or not) within a workflow run.
    """
    class Data(BaseModel):
        """
        Details of the parallel branch that finished.
        """
        parallel_id: str  # id of the parallel construct in the workflow graph
        parallel_branch_id: str  # start node id of this branch
        iteration_id: Optional[str] = None  # iteration id when the branch ran inside an iteration, else None
        status: str  # 'succeeded' or 'failed'
        error: Optional[str] = None  # error message, only set when status is 'failed'
        created_at: int  # unix timestamp (seconds) when the response was built

    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED  # stream event discriminator
    workflow_run_id: str
    data: Data
class IterationNodeStartStreamResponse(StreamResponse): class IterationNodeStartStreamResponse(StreamResponse):

View File

@ -11,6 +11,9 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent, QueueNodeFailedEvent,
QueueNodeStartedEvent, QueueNodeStartedEvent,
QueueNodeSucceededEvent, QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
) )
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
IterationNodeCompletedStreamResponse, IterationNodeCompletedStreamResponse,
@ -18,6 +21,8 @@ from core.app.entities.task_entities import (
IterationNodeStartStreamResponse, IterationNodeStartStreamResponse,
NodeFinishStreamResponse, NodeFinishStreamResponse,
NodeStartStreamResponse, NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse, WorkflowFinishStreamResponse,
WorkflowStartStreamResponse, WorkflowStartStreamResponse,
WorkflowTaskState, WorkflowTaskState,
@ -433,6 +438,56 @@ class WorkflowCycleManage:
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
), ),
) )
def _workflow_parallel_branch_start_to_stream_response(
    self,
    task_id: str,
    workflow_run: WorkflowRun,
    event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
    """
    Build the stream response announcing that a parallel branch started.
    :param task_id: task id
    :param workflow_run: workflow run
    :param event: parallel branch run started event
    :return: parallel branch start stream response
    """
    # Assemble the inner payload first, then wrap it in the envelope.
    data = ParallelBranchStartStreamResponse.Data(
        parallel_id=event.parallel_id,
        parallel_branch_id=event.parallel_start_node_id,
        iteration_id=event.in_iteration_id,
        created_at=int(time.time()),
    )

    return ParallelBranchStartStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=data
    )
def _workflow_parallel_branch_finished_to_stream_response(
    self,
    task_id: str,
    workflow_run: WorkflowRun,
    event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
) -> ParallelBranchFinishedStreamResponse:
    """
    Build the stream response announcing that a parallel branch finished.
    :param task_id: task id
    :param workflow_run: workflow run
    :param event: parallel branch run succeeded or failed event
    :return: parallel branch finished stream response
    """
    # A failed event carries the error; a succeeded one does not.
    failed = isinstance(event, QueueParallelBranchRunFailedEvent)

    data = ParallelBranchFinishedStreamResponse.Data(
        parallel_id=event.parallel_id,
        parallel_branch_id=event.parallel_start_node_id,
        iteration_id=event.in_iteration_id,
        status='failed' if failed else 'succeeded',
        error=event.error if failed else None,
        created_at=int(time.time()),
    )

    return ParallelBranchFinishedStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=data
    )
def _workflow_iteration_start_to_stream_response( def _workflow_iteration_start_to_stream_response(
self, self,

View File

@ -1,18 +1,12 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from core.workflow.graph_engine.entities.event import GraphEngineEvent from core.workflow.graph_engine.entities.event import GraphEngineEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
class WorkflowCallback(ABC): class WorkflowCallback(ABC):
@abstractmethod @abstractmethod
def on_event( def on_event(
self, self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: GraphEngineEvent event: GraphEngineEvent
) -> None: ) -> None:
""" """

View File

@ -56,6 +56,8 @@ class NodeRunMetadataKey(Enum):
TOOL_INFO = 'tool_info' TOOL_INFO = 'tool_info'
ITERATION_ID = 'iteration_id' ITERATION_ID = 'iteration_id'
ITERATION_INDEX = 'iteration_index' ITERATION_INDEX = 'iteration_index'
PARALLEL_ID = 'parallel_id'
PARALLEL_START_NODE_ID = 'parallel_start_node_id'
class NodeRunResult(BaseModel): class NodeRunResult(BaseModel):

View File

@ -99,6 +99,10 @@ class Graph(BaseModel):
if target_node_id not in reverse_edge_mapping: if target_node_id not in reverse_edge_mapping:
reverse_edge_mapping[target_node_id] = [] reverse_edge_mapping[target_node_id] = []
# is target node id in source node id edge mapping
if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]):
continue
target_edge_ids.add(target_node_id) target_edge_ids.add(target_node_id)
# parse run condition # parse run condition

View File

@ -244,6 +244,7 @@ class GraphEngine:
if len(edge_mappings) == 1: if len(edge_mappings) == 1:
edge = edge_mappings[0] edge = edge_mappings[0]
if edge.run_condition: if edge.run_condition:
result = ConditionManager.get_condition_handler( result = ConditionManager.get_condition_handler(
init_params=self.init_params, init_params=self.init_params,
@ -296,14 +297,20 @@ class GraphEngine:
# run parallel nodes, run in new thread and use queue to get results # run parallel nodes, run in new thread and use queue to get results
q: queue.Queue = queue.Queue() q: queue.Queue = queue.Queue()
# Create a list to store the threads
threads = []
# new thread # new thread
for edge in edge_mappings: for edge in edge_mappings:
threading.Thread(target=self._run_parallel_node, kwargs={ thread = threading.Thread(target=self._run_parallel_node, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined] 'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
'parallel_id': parallel_id, 'parallel_id': parallel_id,
'parallel_start_node_id': edge.target_node_id, 'parallel_start_node_id': edge.target_node_id,
'q': q 'q': q
}).start() })
threads.append(thread)
thread.start()
succeeded_count = 0 succeeded_count = 0
while True: while True:
@ -315,8 +322,8 @@ class GraphEngine:
yield event yield event
if isinstance(event, ParallelBranchRunSucceededEvent): if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1 succeeded_count += 1
if succeeded_count == len(edge_mappings): if succeeded_count == len(threads):
break q.put(None)
continue continue
elif isinstance(event, ParallelBranchRunFailedEvent): elif isinstance(event, ParallelBranchRunFailedEvent):
@ -324,6 +331,10 @@ class GraphEngine:
except queue.Empty: except queue.Empty:
continue continue
# Join all threads
for thread in threads:
thread.join()
# get final node id # get final node id
final_node_id = parallel.end_to_node_id final_node_id = parallel.end_to_node_id
if not final_node_id: if not final_node_id:
@ -331,8 +342,8 @@ class GraphEngine:
next_node_id = final_node_id next_node_id = final_node_id
# if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id: if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
# break break
def _run_parallel_node(self, def _run_parallel_node(self,
flask_app: Flask, flask_app: Flask,
@ -449,6 +460,14 @@ class GraphEngine:
variable_value=variable_value variable_value=variable_value
) )
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
yield NodeRunSucceededEvent( yield NodeRunSucceededEvent(
id=node_instance.id, id=node_instance.id,
node_id=node_instance.node_id, node_id=node_instance.node_id,

View File

@ -37,6 +37,10 @@ class AnswerStreamProcessor(StreamProcessor):
yield event yield event
elif isinstance(event, NodeRunStreamChunkEvent): elif isinstance(event, NodeRunStreamChunkEvent):
if event.in_iteration_id:
yield event
continue
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[ stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id event.route_node_state.node_id

View File

@ -157,14 +157,15 @@ class IterationNode(BaseNode):
event.in_iteration_id = self.node_id event.in_iteration_id = self.node_id
if isinstance(event, NodeRunSucceededEvent): if isinstance(event, NodeRunSucceededEvent):
metadata = event.route_node_state.node_run_result.metadata if event.route_node_state.node_run_result:
if not metadata: metadata = event.route_node_state.node_run_result.metadata
metadata = {} if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata: if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index']) metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
event.route_node_state.node_run_result.metadata = metadata event.route_node_state.node_run_result.metadata = metadata
yield event yield event

View File

@ -98,9 +98,6 @@ class WorkflowEntry:
if callbacks: if callbacks:
for callback in callbacks: for callback in callbacks:
callback.on_event( callback.on_event(
graph=self.graph_engine.graph,
graph_init_params=graph_engine.init_params,
graph_runtime_state=graph_engine.graph_runtime_state,
event=event event=event
) )
yield event yield event
@ -111,9 +108,6 @@ class WorkflowEntry:
if callbacks: if callbacks:
for callback in callbacks: for callback in callbacks:
callback.on_event( callback.on_event(
graph=self.graph_engine.graph,
graph_init_params=graph_engine.init_params,
graph_runtime_state=graph_engine.graph_runtime_state,
event=GraphRunFailedEvent( event=GraphRunFailedEvent(
error=str(e) error=str(e)
) )