diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 1f0db7ff34..00b3b9f57e 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -19,6 +19,9 @@ from core.app.entities.queue_entities import ( QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, QueuePingEvent, QueueStopEvent, QueueTextChunkEvent, @@ -280,6 +283,24 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if response: yield response + elif isinstance(event, QueueParallelBranchRunStartedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) elif isinstance(event, QueueIterationStartEvent): if not workflow_run: raise Exception('Workflow run not initialized.') diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 41215a931a..b8b17e0896 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -221,6 +221,8 @@ class NodeStartStreamResponse(StreamResponse): extras: dict = {} parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str @@ -243,6 +245,8 @@ class NodeStartStreamResponse(StreamResponse): "extras": {}, "parallel_id": self.data.parallel_id, "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, } } @@ -274,6 +278,8 @@ class NodeFinishStreamResponse(StreamResponse): files: Optional[list[dict]] = [] parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_FINISHED workflow_run_id: str @@ -303,6 +309,8 @@ class NodeFinishStreamResponse(StreamResponse): "files": [], "parallel_id": self.data.parallel_id, "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, } } diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index a7b9872d45..e4c712ccb0 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -392,6 +392,8 @@ class WorkflowCycleManage: created_at=int(workflow_node_execution.created_at.timestamp()), parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, ), ) @@ -444,6 +446,8 @@ class WorkflowCycleManage: files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, ), ) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 9a79e7e630..6dcd0a52f8 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -516,22 +516,21 @@ class Graph(BaseModel): duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): - if node_id not in merge_branch_node_ids or node_id2 not in branch_node_ids: - continue - # check which node is after if cls._is_node2_after_node1( node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping ): - del merge_branch_node_ids[node_id] + if node_id in merge_branch_node_ids: + del merge_branch_node_ids[node_id] elif cls._is_node2_after_node1( node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping ): - del merge_branch_node_ids[node_id2] + if node_id2 in merge_branch_node_ids: + del merge_branch_node_ids[node_id2] branches_merge_node_ids: dict[str, str] = {} for node_id, branch_node_ids in merge_branch_node_ids.items(): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py index e9637839a8..a402dc0845 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py @@ -245,9 +245,13 @@ def test_parallels_graph(): assert graph.root_node_id == "start" for i in range(3): - assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i+1}" - assert graph.edge_mapping.get(f"llm{i+1}") is not None - assert graph.edge_mapping.get(f"llm{i+1}")[0].target_node_id == "answer" + start_edges = graph.edge_mapping.get("start") + assert start_edges is not None + assert start_edges[i].target_node_id == f"llm{i+1}" + + llm_edges = graph.edge_mapping.get(f"llm{i+1}") + assert llm_edges is not None + assert llm_edges[0].target_node_id == "answer" assert len(graph.parallel_mapping) == 1 assert len(graph.node_parallel_mapping) == 3