mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-13 19:56:04 +08:00)
fix(workflow): missing parallel event in workflow app
This commit is contained in:
parent
77e62f7fee
commit
162e9677c7
@@ -19,6 +19,9 @@ from core.app.entities.queue_entities import (
     QueueNodeFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
+    QueueParallelBranchRunFailedEvent,
+    QueueParallelBranchRunStartedEvent,
+    QueueParallelBranchRunSucceededEvent,
     QueuePingEvent,
     QueueStopEvent,
     QueueTextChunkEvent,
@@ -280,6 +283,24 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):

                 if response:
                     yield response
+            elif isinstance(event, QueueParallelBranchRunStartedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_parallel_branch_start_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_parallel_branch_finished_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
             elif isinstance(event, QueueIterationStartEvent):
                 if not workflow_run:
                     raise Exception('Workflow run not initialized.')
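One detail worth noting in the hunk above: the succeeded/failed branch passes a PEP 604 union (`X | Y`) straight to `isinstance`, which Python supports from 3.10 onward. A minimal standalone illustration (the event classes here are stand-ins, not the dify types):

    # Python 3.10+: isinstance() accepts X | Y unions as well as tuples.
    class BranchSucceeded: ...
    class BranchFailed: ...

    event = BranchFailed()

    if isinstance(event, BranchSucceeded | BranchFailed):  # same as (BranchSucceeded, BranchFailed)
        print("parallel branch finished")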
@@ -221,6 +221,8 @@ class NodeStartStreamResponse(StreamResponse):
         extras: dict = {}
         parallel_id: Optional[str] = None
         parallel_start_node_id: Optional[str] = None
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None

     event: StreamEvent = StreamEvent.NODE_STARTED
     workflow_run_id: str
@@ -243,6 +245,8 @@ class NodeStartStreamResponse(StreamResponse):
                 "extras": {},
                 "parallel_id": self.data.parallel_id,
                 "parallel_start_node_id": self.data.parallel_start_node_id,
+                "parent_parallel_id": self.data.parent_parallel_id,
+                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
             }
         }

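With the two new fields serialized, a node_started stream event would carry the parallel context roughly as below; every value here is invented for illustration, only the key names come from the diff:

    # Illustrative node_started payload; values are made up.
    payload = {
        "event": "node_started",
        "task_id": "task-456",
        "workflow_run_id": "run-123",
        "data": {
            "parallel_id": "p-1",
            "parallel_start_node_id": "llm1",
            "parent_parallel_id": None,  # non-null only when the parallel is nested in another parallel
            "parent_parallel_start_node_id": None,
        },
    }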
@@ -274,6 +278,8 @@ class NodeFinishStreamResponse(StreamResponse):
         files: Optional[list[dict]] = []
         parallel_id: Optional[str] = None
         parallel_start_node_id: Optional[str] = None
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None

     event: StreamEvent = StreamEvent.NODE_FINISHED
     workflow_run_id: str
@@ -303,6 +309,8 @@ class NodeFinishStreamResponse(StreamResponse):
                 "files": [],
                 "parallel_id": self.data.parallel_id,
                 "parallel_start_node_id": self.data.parallel_start_node_id,
+                "parent_parallel_id": self.data.parent_parallel_id,
+                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
             }
         }

@@ -392,6 +392,8 @@ class WorkflowCycleManage:
                 created_at=int(workflow_node_execution.created_at.timestamp()),
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
             ),
         )
@@ -444,6 +446,8 @@ class WorkflowCycleManage:
                 files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
             ),
         )

@@ -516,22 +516,21 @@ class Graph(BaseModel):
                     duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids

         for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
-            if node_id not in merge_branch_node_ids or node_id2 not in branch_node_ids:
-                continue
-
             # check which node is after
             if cls._is_node2_after_node1(
                 node1_id=node_id,
                 node2_id=node_id2,
                 edge_mapping=edge_mapping
             ):
-                del merge_branch_node_ids[node_id]
+                if node_id in merge_branch_node_ids:
+                    del merge_branch_node_ids[node_id]
             elif cls._is_node2_after_node1(
                 node1_id=node_id2,
                 node2_id=node_id,
                 edge_mapping=edge_mapping
             ):
-                del merge_branch_node_ids[node_id2]
+                if node_id2 in merge_branch_node_ids:
+                    del merge_branch_node_ids[node_id2]

         branches_merge_node_ids: dict[str, str] = {}
         for node_id, branch_node_ids in merge_branch_node_ids.items():
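The reworked deletion is defensive: the same node id can appear in more than one (node_id, node_id2) pair, so a later iteration may try to delete a key that an earlier one already removed, and an unguarded `del` raises KeyError. A standalone sketch of the pattern:

    merge_branch_node_ids = {"join": ["a", "b"]}

    # Deleting twice: the membership guard makes the second pass a no-op.
    for _ in range(2):
        if "join" in merge_branch_node_ids:
            del merge_branch_node_ids["join"]

    # dict.pop with a default is an equivalent one-liner.
    merge_branch_node_ids.pop("join", None)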
@@ -245,9 +245,13 @@ def test_parallels_graph():

     assert graph.root_node_id == "start"
     for i in range(3):
-        assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i+1}"
-        assert graph.edge_mapping.get(f"llm{i+1}") is not None
-        assert graph.edge_mapping.get(f"llm{i+1}")[0].target_node_id == "answer"
+        start_edges = graph.edge_mapping.get("start")
+        assert start_edges is not None
+        assert start_edges[i].target_node_id == f"llm{i+1}"
+
+        llm_edges = graph.edge_mapping.get(f"llm{i+1}")
+        assert llm_edges is not None
+        assert llm_edges[0].target_node_id == "answer"

     assert len(graph.parallel_mapping) == 1
     assert len(graph.node_parallel_mapping) == 3
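The test refactor is about None safety: `dict.get` returns an Optional, so indexing the call result directly (`graph.edge_mapping.get("start")[i]`) would raise TypeError if the key were missing, and static checkers such as mypy flag the subscript. Binding the result to a local and asserting it is not None narrows the type before indexing; a minimal illustration with hypothetical data:

    edge_mapping: dict[str, list[str]] = {"start": ["llm1", "llm2", "llm3"]}

    edges = edge_mapping.get("start")  # inferred as Optional[list[str]]
    assert edges is not None           # narrows to list[str] for the checker
    assert edges[0] == "llm1"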