fix(workflow): missing parallel event in workflow app

This commit is contained in:
takatost 2024-08-30 20:04:17 +08:00
parent 77e62f7fee
commit 162e9677c7
5 changed files with 44 additions and 8 deletions

View File

@ -19,6 +19,9 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
@ -280,6 +283,24 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')

View File

@ -221,6 +221,8 @@ class NodeStartStreamResponse(StreamResponse):
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
@ -243,6 +245,8 @@ class NodeStartStreamResponse(StreamResponse):
"extras": {},
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
}
}
@ -274,6 +278,8 @@ class NodeFinishStreamResponse(StreamResponse):
files: Optional[list[dict]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str
@ -303,6 +309,8 @@ class NodeFinishStreamResponse(StreamResponse):
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
}
}

View File

@ -392,6 +392,8 @@ class WorkflowCycleManage:
created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
),
)
@ -444,6 +446,8 @@ class WorkflowCycleManage:
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
),
)

View File

@ -516,22 +516,21 @@ class Graph(BaseModel):
duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids
for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
if node_id not in merge_branch_node_ids or node_id2 not in branch_node_ids:
continue
# check which node is after
if cls._is_node2_after_node1(
node1_id=node_id,
node2_id=node_id2,
edge_mapping=edge_mapping
):
del merge_branch_node_ids[node_id]
if node_id in merge_branch_node_ids:
del merge_branch_node_ids[node_id]
elif cls._is_node2_after_node1(
node1_id=node_id2,
node2_id=node_id,
edge_mapping=edge_mapping
):
del merge_branch_node_ids[node_id2]
if node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id2]
branches_merge_node_ids: dict[str, str] = {}
for node_id, branch_node_ids in merge_branch_node_ids.items():

View File

@ -245,9 +245,13 @@ def test_parallels_graph():
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i+1}"
assert graph.edge_mapping.get(f"llm{i+1}") is not None
assert graph.edge_mapping.get(f"llm{i+1}")[0].target_node_id == "answer"
start_edges = graph.edge_mapping.get("start")
assert start_edges is not None
assert start_edges[i].target_node_id == f"llm{i+1}"
llm_edges = graph.edge_mapping.get(f"llm{i+1}")
assert llm_edges is not None
assert llm_edges[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3