fix(workflow): missing parallel event in workflow app

This commit is contained in:
takatost 2024-08-30 20:04:17 +08:00
parent 77e62f7fee
commit 162e9677c7
5 changed files with 44 additions and 8 deletions

View File

@ -19,6 +19,9 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent, QueueNodeFailedEvent,
QueueNodeStartedEvent, QueueNodeStartedEvent,
QueueNodeSucceededEvent, QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent, QueuePingEvent,
QueueStopEvent, QueueStopEvent,
QueueTextChunkEvent, QueueTextChunkEvent,
@ -280,6 +283,24 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if response: if response:
yield response yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationStartEvent): elif isinstance(event, QueueIterationStartEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception('Workflow run not initialized.')

View File

@ -221,6 +221,8 @@ class NodeStartStreamResponse(StreamResponse):
extras: dict = {} extras: dict = {}
parallel_id: Optional[str] = None parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_STARTED event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str workflow_run_id: str
@ -243,6 +245,8 @@ class NodeStartStreamResponse(StreamResponse):
"extras": {}, "extras": {},
"parallel_id": self.data.parallel_id, "parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id, "parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
} }
} }
@ -274,6 +278,8 @@ class NodeFinishStreamResponse(StreamResponse):
files: Optional[list[dict]] = [] files: Optional[list[dict]] = []
parallel_id: Optional[str] = None parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_FINISHED event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str workflow_run_id: str
@ -303,6 +309,8 @@ class NodeFinishStreamResponse(StreamResponse):
"files": [], "files": [],
"parallel_id": self.data.parallel_id, "parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id, "parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
} }
} }

View File

@ -392,6 +392,8 @@ class WorkflowCycleManage:
created_at=int(workflow_node_execution.created_at.timestamp()), created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
), ),
) )
@ -444,6 +446,8 @@ class WorkflowCycleManage:
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
), ),
) )

View File

@ -516,22 +516,21 @@ class Graph(BaseModel):
duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids
for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
if node_id not in merge_branch_node_ids or node_id2 not in branch_node_ids:
continue
# check which node is after # check which node is after
if cls._is_node2_after_node1( if cls._is_node2_after_node1(
node1_id=node_id, node1_id=node_id,
node2_id=node_id2, node2_id=node_id2,
edge_mapping=edge_mapping edge_mapping=edge_mapping
): ):
del merge_branch_node_ids[node_id] if node_id in merge_branch_node_ids:
del merge_branch_node_ids[node_id]
elif cls._is_node2_after_node1( elif cls._is_node2_after_node1(
node1_id=node_id2, node1_id=node_id2,
node2_id=node_id, node2_id=node_id,
edge_mapping=edge_mapping edge_mapping=edge_mapping
): ):
del merge_branch_node_ids[node_id2] if node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id2]
branches_merge_node_ids: dict[str, str] = {} branches_merge_node_ids: dict[str, str] = {}
for node_id, branch_node_ids in merge_branch_node_ids.items(): for node_id, branch_node_ids in merge_branch_node_ids.items():

View File

@ -245,9 +245,13 @@ def test_parallels_graph():
assert graph.root_node_id == "start" assert graph.root_node_id == "start"
for i in range(3): for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i+1}" start_edges = graph.edge_mapping.get("start")
assert graph.edge_mapping.get(f"llm{i+1}") is not None assert start_edges is not None
assert graph.edge_mapping.get(f"llm{i+1}")[0].target_node_id == "answer" assert start_edges[i].target_node_id == f"llm{i+1}"
llm_edges = graph.edge_mapping.get(f"llm{i+1}")
assert llm_edges is not None
assert llm_edges[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1 assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3 assert len(graph.node_parallel_mapping) == 3