mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 23:35:53 +08:00
fix(workflow): parallel execution after if-else that only one branch runs
This commit is contained in:
parent
cd52633b0e
commit
b0a81c654b
@ -19,14 +19,13 @@ class RunConditionHandler(ABC):
|
||||
@abstractmethod
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
previous_route_node_state: RouteNodeState,
|
||||
target_node_id: str) -> bool:
|
||||
previous_route_node_state: RouteNodeState
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param previous_route_node_state: previous route node state
|
||||
:param target_node_id: target node id
|
||||
:return: bool
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
@ -7,14 +7,12 @@ class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
||||
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
previous_route_node_state: RouteNodeState,
|
||||
target_node_id: str) -> bool:
|
||||
previous_route_node_state: RouteNodeState) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param previous_route_node_state: previous route node state
|
||||
:param target_node_id: target node id
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.branch_identify:
|
||||
|
@ -7,14 +7,13 @@ from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
previous_route_node_state: RouteNodeState,
|
||||
target_node_id: str) -> bool:
|
||||
previous_route_node_state: RouteNodeState
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param previous_route_node_state: previous route node state
|
||||
:param target_node_id: target node id
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.conditions:
|
||||
|
@ -314,8 +314,22 @@ class Graph(BaseModel):
|
||||
target_node_edges = edge_mapping.get(start_node_id, [])
|
||||
if len(target_node_edges) > 1:
|
||||
# fetch all node ids in current parallels
|
||||
parallel_node_ids = [graph_edge.target_node_id
|
||||
for graph_edge in target_node_edges if graph_edge.run_condition is None]
|
||||
parallel_node_ids = []
|
||||
condition_edge_mappings = {}
|
||||
for graph_edge in target_node_edges:
|
||||
if graph_edge.run_condition is None:
|
||||
parallel_node_ids.append(graph_edge.target_node_id)
|
||||
else:
|
||||
condition_hash = graph_edge.run_condition.hash
|
||||
if not condition_hash in condition_edge_mappings:
|
||||
condition_edge_mappings[condition_hash] = []
|
||||
|
||||
condition_edge_mappings[condition_hash].append(graph_edge)
|
||||
|
||||
for _, graph_edges in condition_edge_mappings.items():
|
||||
if len(graph_edges) > 1:
|
||||
for graph_edge in graph_edges:
|
||||
parallel_node_ids.append(graph_edge.target_node_id)
|
||||
|
||||
# any target node id in node_parallel_mapping
|
||||
if parallel_node_ids:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -14,3 +15,7 @@ class RunCondition(BaseModel):
|
||||
|
||||
conditions: Optional[list[Condition]] = None
|
||||
"""conditions to run the node, required when type is condition"""
|
||||
|
||||
@property
|
||||
def hash(self) -> str:
|
||||
return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
|
@ -32,7 +32,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph import Graph, GraphEdge
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
@ -262,8 +262,7 @@ class GraphEngine:
|
||||
run_condition=edge.run_condition,
|
||||
).check(
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
previous_route_node_state=previous_route_node_state,
|
||||
target_node_id=edge.target_node_id,
|
||||
previous_route_node_state=previous_route_node_state
|
||||
)
|
||||
|
||||
if not result:
|
||||
@ -274,90 +273,137 @@ class GraphEngine:
|
||||
if any(edge.run_condition for edge in edge_mappings):
|
||||
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
||||
final_node_id = None
|
||||
|
||||
condition_edge_mappings = {}
|
||||
for edge in edge_mappings:
|
||||
if edge.run_condition:
|
||||
result = ConditionManager.get_condition_handler(
|
||||
init_params=self.init_params,
|
||||
graph=self.graph,
|
||||
run_condition=edge.run_condition,
|
||||
).check(
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
previous_route_node_state=previous_route_node_state,
|
||||
target_node_id=edge.target_node_id,
|
||||
run_condition_hash = edge.run_condition.hash
|
||||
if run_condition_hash not in condition_edge_mappings:
|
||||
condition_edge_mappings[run_condition_hash] = []
|
||||
|
||||
condition_edge_mappings[run_condition_hash].append(edge)
|
||||
|
||||
for _, sub_edge_mappings in condition_edge_mappings.items():
|
||||
if len(sub_edge_mappings) == 0:
|
||||
continue
|
||||
|
||||
edge = sub_edge_mappings[0]
|
||||
|
||||
result = ConditionManager.get_condition_handler(
|
||||
init_params=self.init_params,
|
||||
graph=self.graph,
|
||||
run_condition=edge.run_condition,
|
||||
).check(
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
previous_route_node_state=previous_route_node_state,
|
||||
)
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
if len(sub_edge_mappings) == 1:
|
||||
final_node_id = edge.target_node_id
|
||||
else:
|
||||
final_node_id, parallel_generator = self._run_parallel_branches(
|
||||
edge_mappings=sub_edge_mappings,
|
||||
in_parallel_id=in_parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id
|
||||
)
|
||||
|
||||
if result:
|
||||
final_node_id = edge.target_node_id
|
||||
break
|
||||
yield from parallel_generator
|
||||
|
||||
break
|
||||
|
||||
if not final_node_id:
|
||||
break
|
||||
|
||||
next_node_id = final_node_id
|
||||
else:
|
||||
# if nodes has no run conditions, parallel run all nodes
|
||||
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
||||
if not parallel_id:
|
||||
raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.')
|
||||
next_node_id, parallel_generator = self._run_parallel_branches(
|
||||
edge_mappings=edge_mappings,
|
||||
in_parallel_id=in_parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id
|
||||
)
|
||||
|
||||
parallel = self.graph.parallel_mapping.get(parallel_id)
|
||||
if not parallel:
|
||||
raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
|
||||
yield from parallel_generator
|
||||
|
||||
# run parallel nodes, run in new thread and use queue to get results
|
||||
q: queue.Queue = queue.Queue()
|
||||
|
||||
# Create a list to store the threads
|
||||
threads = []
|
||||
|
||||
# new thread
|
||||
for edge in edge_mappings:
|
||||
thread = threading.Thread(target=self._run_parallel_node, kwargs={
|
||||
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
|
||||
'q': q,
|
||||
'parallel_id': parallel_id,
|
||||
'parallel_start_node_id': edge.target_node_id,
|
||||
'parent_parallel_id': in_parallel_id,
|
||||
'parent_parallel_start_node_id': parallel_start_node_id,
|
||||
})
|
||||
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
succeeded_count = 0
|
||||
while True:
|
||||
try:
|
||||
event = q.get(timeout=1)
|
||||
if event is None:
|
||||
break
|
||||
|
||||
yield event
|
||||
if event.parallel_id == parallel_id:
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(threads):
|
||||
q.put(None)
|
||||
|
||||
continue
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
raise GraphRunFailedError(event.error)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
# Join all threads
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# get final node id
|
||||
final_node_id = parallel.end_to_node_id
|
||||
if not final_node_id:
|
||||
if not next_node_id:
|
||||
break
|
||||
|
||||
next_node_id = final_node_id
|
||||
|
||||
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
|
||||
break
|
||||
|
||||
def _run_parallel_branches(
|
||||
self,
|
||||
edge_mappings: list[GraphEdge],
|
||||
in_parallel_id: Optional[str] = None,
|
||||
parallel_start_node_id: Optional[str] = None,
|
||||
) -> tuple[Optional[str], Generator[GraphEngineEvent, None, None]]:
|
||||
# if nodes has no run conditions, parallel run all nodes
|
||||
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
||||
if not parallel_id:
|
||||
raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.')
|
||||
|
||||
parallel = self.graph.parallel_mapping.get(parallel_id)
|
||||
if not parallel:
|
||||
raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
|
||||
|
||||
# run parallel nodes, run in new thread and use queue to get results
|
||||
q: queue.Queue = queue.Queue()
|
||||
|
||||
# Create a list to store the threads
|
||||
threads = []
|
||||
|
||||
# new thread
|
||||
for edge in edge_mappings:
|
||||
thread = threading.Thread(target=self._run_parallel_node, kwargs={
|
||||
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
|
||||
'q': q,
|
||||
'parallel_id': parallel_id,
|
||||
'parallel_start_node_id': edge.target_node_id,
|
||||
'parent_parallel_id': in_parallel_id,
|
||||
'parent_parallel_start_node_id': parallel_start_node_id,
|
||||
})
|
||||
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
def parallel_generator() -> Generator[GraphEngineEvent, None, None]:
|
||||
succeeded_count = 0
|
||||
while True:
|
||||
try:
|
||||
event = q.get(timeout=1)
|
||||
if event is None:
|
||||
break
|
||||
|
||||
yield event
|
||||
if event.parallel_id == parallel_id:
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(threads):
|
||||
q.put(None)
|
||||
|
||||
continue
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
raise GraphRunFailedError(event.error)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
generator = parallel_generator()
|
||||
|
||||
# Join all threads
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# get final node id
|
||||
final_node_id = parallel.end_to_node_id
|
||||
if not final_node_id:
|
||||
return None, generator
|
||||
|
||||
next_node_id = final_node_id
|
||||
|
||||
return final_node_id, generator
|
||||
|
||||
def _run_parallel_node(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
|
@ -99,9 +99,7 @@ class AppGenerateService:
|
||||
return max_active_requests
|
||||
|
||||
@classmethod
|
||||
def generate_single_iteration(
|
||||
cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True
|
||||
):
|
||||
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator().single_iteration_generate(
|
||||
|
@ -234,7 +234,7 @@ class WorkflowService:
|
||||
break
|
||||
|
||||
if not node_run_result:
|
||||
raise ValueError('Node run failed with no run result')
|
||||
raise ValueError("Node run failed with no run result")
|
||||
|
||||
run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False
|
||||
error = node_run_result.error if not run_succeeded else None
|
||||
@ -262,9 +262,15 @@ class WorkflowService:
|
||||
if run_succeeded and node_run_result:
|
||||
# create workflow node execution
|
||||
workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(node_run_result.process_data) if node_run_result.process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None
|
||||
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
|
||||
workflow_node_execution.process_data = (
|
||||
json.dumps(node_run_result.process_data) if node_run_result.process_data else None
|
||||
)
|
||||
workflow_node_execution.outputs = (
|
||||
json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None
|
||||
)
|
||||
workflow_node_execution.execution_metadata = (
|
||||
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
|
||||
)
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
else:
|
||||
# create workflow node execution
|
||||
|
Loading…
x
Reference in New Issue
Block a user