mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 02:35:56 +08:00
fix(workflow): parallel not yield
This commit is contained in:
parent
8ba5673606
commit
c2bb11405f
@ -304,13 +304,17 @@ class GraphEngine:
|
||||
if len(sub_edge_mappings) == 1:
|
||||
final_node_id = edge.target_node_id
|
||||
else:
|
||||
final_node_id, parallel_generator = self._run_parallel_branches(
|
||||
parallel_generator = self._run_parallel_branches(
|
||||
edge_mappings=sub_edge_mappings,
|
||||
in_parallel_id=in_parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id
|
||||
)
|
||||
|
||||
yield from parallel_generator
|
||||
for item in parallel_generator:
|
||||
if isinstance(item, str):
|
||||
final_node_id = item
|
||||
else:
|
||||
yield item
|
||||
|
||||
break
|
||||
|
||||
@ -319,13 +323,17 @@ class GraphEngine:
|
||||
|
||||
next_node_id = final_node_id
|
||||
else:
|
||||
next_node_id, parallel_generator = self._run_parallel_branches(
|
||||
parallel_generator = self._run_parallel_branches(
|
||||
edge_mappings=edge_mappings,
|
||||
in_parallel_id=in_parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id
|
||||
)
|
||||
|
||||
yield from parallel_generator
|
||||
for item in parallel_generator:
|
||||
if isinstance(item, str):
|
||||
next_node_id = item
|
||||
else:
|
||||
yield item
|
||||
|
||||
if not next_node_id:
|
||||
break
|
||||
@ -338,7 +346,7 @@ class GraphEngine:
|
||||
edge_mappings: list[GraphEdge],
|
||||
in_parallel_id: Optional[str] = None,
|
||||
parallel_start_node_id: Optional[str] = None,
|
||||
) -> tuple[Optional[str], Generator[GraphEngineEvent, None, None]]:
|
||||
) -> Generator[GraphEngineEvent | str, None, None]:
|
||||
# if nodes has no run conditions, parallel run all nodes
|
||||
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
||||
if not parallel_id:
|
||||
@ -368,41 +376,34 @@ class GraphEngine:
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
def parallel_generator() -> Generator[GraphEngineEvent, None, None]:
|
||||
succeeded_count = 0
|
||||
while True:
|
||||
try:
|
||||
event = q.get(timeout=1)
|
||||
if event is None:
|
||||
break
|
||||
succeeded_count = 0
|
||||
while True:
|
||||
try:
|
||||
event = q.get(timeout=1)
|
||||
if event is None:
|
||||
break
|
||||
|
||||
yield event
|
||||
if event.parallel_id == parallel_id:
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(threads):
|
||||
q.put(None)
|
||||
yield event
|
||||
if event.parallel_id == parallel_id:
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(threads):
|
||||
q.put(None)
|
||||
|
||||
continue
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
raise GraphRunFailedError(event.error)
|
||||
except queue.Empty:
|
||||
continue
|
||||
continue
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
raise GraphRunFailedError(event.error)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
generator = parallel_generator()
|
||||
|
||||
# Join all threads
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# get final node id
|
||||
final_node_id = parallel.end_to_node_id
|
||||
if not final_node_id:
|
||||
return None, generator
|
||||
|
||||
next_node_id = final_node_id
|
||||
|
||||
return final_node_id, generator
|
||||
if final_node_id:
|
||||
yield final_node_id
|
||||
|
||||
def _run_parallel_node(
|
||||
self,
|
||||
|
Loading…
x
Reference in New Issue
Block a user