mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 15:36:00 +08:00
fix(workflow): parallel execution after if-else that only one branch runs
This commit is contained in:
parent
cd52633b0e
commit
b0a81c654b
@ -19,14 +19,13 @@ class RunConditionHandler(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def check(self,
|
def check(self,
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
previous_route_node_state: RouteNodeState,
|
previous_route_node_state: RouteNodeState
|
||||||
target_node_id: str) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the condition can be executed
|
Check if the condition can be executed
|
||||||
|
|
||||||
:param graph_runtime_state: graph runtime state
|
:param graph_runtime_state: graph runtime state
|
||||||
:param previous_route_node_state: previous route node state
|
:param previous_route_node_state: previous route node state
|
||||||
:param target_node_id: target node id
|
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -7,14 +7,12 @@ class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
|||||||
|
|
||||||
def check(self,
|
def check(self,
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
previous_route_node_state: RouteNodeState,
|
previous_route_node_state: RouteNodeState) -> bool:
|
||||||
target_node_id: str) -> bool:
|
|
||||||
"""
|
"""
|
||||||
Check if the condition can be executed
|
Check if the condition can be executed
|
||||||
|
|
||||||
:param graph_runtime_state: graph runtime state
|
:param graph_runtime_state: graph runtime state
|
||||||
:param previous_route_node_state: previous route node state
|
:param previous_route_node_state: previous route node state
|
||||||
:param target_node_id: target node id
|
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
if not self.condition.branch_identify:
|
if not self.condition.branch_identify:
|
||||||
|
@ -7,14 +7,13 @@ from core.workflow.utils.condition.processor import ConditionProcessor
|
|||||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||||
def check(self,
|
def check(self,
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
previous_route_node_state: RouteNodeState,
|
previous_route_node_state: RouteNodeState
|
||||||
target_node_id: str) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the condition can be executed
|
Check if the condition can be executed
|
||||||
|
|
||||||
:param graph_runtime_state: graph runtime state
|
:param graph_runtime_state: graph runtime state
|
||||||
:param previous_route_node_state: previous route node state
|
:param previous_route_node_state: previous route node state
|
||||||
:param target_node_id: target node id
|
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
if not self.condition.conditions:
|
if not self.condition.conditions:
|
||||||
|
@ -314,8 +314,22 @@ class Graph(BaseModel):
|
|||||||
target_node_edges = edge_mapping.get(start_node_id, [])
|
target_node_edges = edge_mapping.get(start_node_id, [])
|
||||||
if len(target_node_edges) > 1:
|
if len(target_node_edges) > 1:
|
||||||
# fetch all node ids in current parallels
|
# fetch all node ids in current parallels
|
||||||
parallel_node_ids = [graph_edge.target_node_id
|
parallel_node_ids = []
|
||||||
for graph_edge in target_node_edges if graph_edge.run_condition is None]
|
condition_edge_mappings = {}
|
||||||
|
for graph_edge in target_node_edges:
|
||||||
|
if graph_edge.run_condition is None:
|
||||||
|
parallel_node_ids.append(graph_edge.target_node_id)
|
||||||
|
else:
|
||||||
|
condition_hash = graph_edge.run_condition.hash
|
||||||
|
if not condition_hash in condition_edge_mappings:
|
||||||
|
condition_edge_mappings[condition_hash] = []
|
||||||
|
|
||||||
|
condition_edge_mappings[condition_hash].append(graph_edge)
|
||||||
|
|
||||||
|
for _, graph_edges in condition_edge_mappings.items():
|
||||||
|
if len(graph_edges) > 1:
|
||||||
|
for graph_edge in graph_edges:
|
||||||
|
parallel_node_ids.append(graph_edge.target_node_id)
|
||||||
|
|
||||||
# any target node id in node_parallel_mapping
|
# any target node id in node_parallel_mapping
|
||||||
if parallel_node_ids:
|
if parallel_node_ids:
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import hashlib
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -14,3 +15,7 @@ class RunCondition(BaseModel):
|
|||||||
|
|
||||||
conditions: Optional[list[Condition]] = None
|
conditions: Optional[list[Condition]] = None
|
||||||
"""conditions to run the node, required when type is condition"""
|
"""conditions to run the node, required when type is condition"""
|
||||||
|
|
||||||
|
@property
def hash(self) -> str:
    """
    Fingerprint of this run condition.

    :return: SHA-256 hex digest of the model's JSON serialization
    """
    # Serialize first, then digest, so equal field values yield equal hashes.
    serialized = self.model_dump_json()
    digest = hashlib.sha256(serialized.encode())
    return digest.hexdigest()
|
@ -32,7 +32,7 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
ParallelBranchRunStartedEvent,
|
ParallelBranchRunStartedEvent,
|
||||||
ParallelBranchRunSucceededEvent,
|
ParallelBranchRunSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph, GraphEdge
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
@ -262,8 +262,7 @@ class GraphEngine:
|
|||||||
run_condition=edge.run_condition,
|
run_condition=edge.run_condition,
|
||||||
).check(
|
).check(
|
||||||
graph_runtime_state=self.graph_runtime_state,
|
graph_runtime_state=self.graph_runtime_state,
|
||||||
previous_route_node_state=previous_route_node_state,
|
previous_route_node_state=previous_route_node_state
|
||||||
target_node_id=edge.target_node_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
@ -274,90 +273,137 @@ class GraphEngine:
|
|||||||
if any(edge.run_condition for edge in edge_mappings):
|
if any(edge.run_condition for edge in edge_mappings):
|
||||||
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
||||||
final_node_id = None
|
final_node_id = None
|
||||||
|
|
||||||
|
condition_edge_mappings = {}
|
||||||
for edge in edge_mappings:
|
for edge in edge_mappings:
|
||||||
if edge.run_condition:
|
if edge.run_condition:
|
||||||
result = ConditionManager.get_condition_handler(
|
run_condition_hash = edge.run_condition.hash
|
||||||
init_params=self.init_params,
|
if run_condition_hash not in condition_edge_mappings:
|
||||||
graph=self.graph,
|
condition_edge_mappings[run_condition_hash] = []
|
||||||
run_condition=edge.run_condition,
|
|
||||||
).check(
|
condition_edge_mappings[run_condition_hash].append(edge)
|
||||||
graph_runtime_state=self.graph_runtime_state,
|
|
||||||
previous_route_node_state=previous_route_node_state,
|
for _, sub_edge_mappings in condition_edge_mappings.items():
|
||||||
target_node_id=edge.target_node_id,
|
if len(sub_edge_mappings) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
edge = sub_edge_mappings[0]
|
||||||
|
|
||||||
|
result = ConditionManager.get_condition_handler(
|
||||||
|
init_params=self.init_params,
|
||||||
|
graph=self.graph,
|
||||||
|
run_condition=edge.run_condition,
|
||||||
|
).check(
|
||||||
|
graph_runtime_state=self.graph_runtime_state,
|
||||||
|
previous_route_node_state=previous_route_node_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(sub_edge_mappings) == 1:
|
||||||
|
final_node_id = edge.target_node_id
|
||||||
|
else:
|
||||||
|
final_node_id, parallel_generator = self._run_parallel_branches(
|
||||||
|
edge_mappings=sub_edge_mappings,
|
||||||
|
in_parallel_id=in_parallel_id,
|
||||||
|
parallel_start_node_id=parallel_start_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if result:
|
yield from parallel_generator
|
||||||
final_node_id = edge.target_node_id
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if not final_node_id:
|
if not final_node_id:
|
||||||
break
|
break
|
||||||
|
|
||||||
next_node_id = final_node_id
|
next_node_id = final_node_id
|
||||||
else:
|
else:
|
||||||
# if nodes has no run conditions, parallel run all nodes
|
next_node_id, parallel_generator = self._run_parallel_branches(
|
||||||
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
edge_mappings=edge_mappings,
|
||||||
if not parallel_id:
|
in_parallel_id=in_parallel_id,
|
||||||
raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.')
|
parallel_start_node_id=parallel_start_node_id
|
||||||
|
)
|
||||||
|
|
||||||
parallel = self.graph.parallel_mapping.get(parallel_id)
|
yield from parallel_generator
|
||||||
if not parallel:
|
|
||||||
raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
|
|
||||||
|
|
||||||
# run parallel nodes, run in new thread and use queue to get results
|
if not next_node_id:
|
||||||
q: queue.Queue = queue.Queue()
|
|
||||||
|
|
||||||
# Create a list to store the threads
|
|
||||||
threads = []
|
|
||||||
|
|
||||||
# new thread
|
|
||||||
for edge in edge_mappings:
|
|
||||||
thread = threading.Thread(target=self._run_parallel_node, kwargs={
|
|
||||||
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
|
|
||||||
'q': q,
|
|
||||||
'parallel_id': parallel_id,
|
|
||||||
'parallel_start_node_id': edge.target_node_id,
|
|
||||||
'parent_parallel_id': in_parallel_id,
|
|
||||||
'parent_parallel_start_node_id': parallel_start_node_id,
|
|
||||||
})
|
|
||||||
|
|
||||||
threads.append(thread)
|
|
||||||
thread.start()
|
|
||||||
|
|
||||||
succeeded_count = 0
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
event = q.get(timeout=1)
|
|
||||||
if event is None:
|
|
||||||
break
|
|
||||||
|
|
||||||
yield event
|
|
||||||
if event.parallel_id == parallel_id:
|
|
||||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
|
||||||
succeeded_count += 1
|
|
||||||
if succeeded_count == len(threads):
|
|
||||||
q.put(None)
|
|
||||||
|
|
||||||
continue
|
|
||||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
|
||||||
raise GraphRunFailedError(event.error)
|
|
||||||
except queue.Empty:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Join all threads
|
|
||||||
for thread in threads:
|
|
||||||
thread.join()
|
|
||||||
|
|
||||||
# get final node id
|
|
||||||
final_node_id = parallel.end_to_node_id
|
|
||||||
if not final_node_id:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
next_node_id = final_node_id
|
|
||||||
|
|
||||||
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
|
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def _run_parallel_branches(
    self,
    edge_mappings: list[GraphEdge],
    in_parallel_id: Optional[str] = None,
    parallel_start_node_id: Optional[str] = None,
) -> tuple[Optional[str], Generator[GraphEngineEvent, None, None]]:
    """
    Start one worker thread per edge and expose their events as a generator.

    The worker threads are started before this method returns. The returned
    generator must be consumed by the caller: it forwards every queued event
    as it arrives and joins the worker threads once all branches have
    reported success.

    :param edge_mappings: edges whose target nodes are run as parallel branches
    :param in_parallel_id: id of the enclosing parallel, if any
    :param parallel_start_node_id: start node id of the enclosing parallel branch, if any
    :return: tuple of (node id to continue with after the parallel, or None;
        generator yielding the events produced by the branches)
    :raises GraphRunFailedError: when the parallel of the first target node
        cannot be resolved; also raised from the generator when a branch of
        this parallel reports ParallelBranchRunFailedEvent
    """
    # the parallel is looked up from the first edge's target node
    first_target_node_id = edge_mappings[0].target_node_id
    parallel_id = self.graph.node_parallel_mapping.get(first_target_node_id)
    if not parallel_id:
        raise GraphRunFailedError(f'Node {first_target_node_id} related parallel not found.')

    parallel = self.graph.parallel_mapping.get(parallel_id)
    if not parallel:
        raise GraphRunFailedError(f'Parallel {parallel_id} not found.')

    # run parallel nodes, run in new thread and use queue to get results
    q: queue.Queue = queue.Queue()

    # resolve the app object once, in the calling thread, for all workers
    flask_app = current_app._get_current_object()  # type: ignore[attr-defined]

    threads = []
    for edge in edge_mappings:
        thread = threading.Thread(target=self._run_parallel_node, kwargs={
            'flask_app': flask_app,
            'q': q,
            'parallel_id': parallel_id,
            'parallel_start_node_id': edge.target_node_id,
            'parent_parallel_id': in_parallel_id,
            'parent_parallel_start_node_id': parallel_start_node_id,
        })

        threads.append(thread)
        thread.start()

    def parallel_generator() -> Generator[GraphEngineEvent, None, None]:
        succeeded_count = 0
        while True:
            try:
                event = q.get(timeout=1)
            except queue.Empty:
                # nothing produced within the timeout, keep polling
                continue

            if event is None:
                # sentinel enqueued below once every branch has succeeded
                break

            yield event
            if event.parallel_id == parallel_id:
                if isinstance(event, ParallelBranchRunSucceededEvent):
                    succeeded_count += 1
                    if succeeded_count == len(threads):
                        q.put(None)
                elif isinstance(event, ParallelBranchRunFailedEvent):
                    raise GraphRunFailedError(event.error)

        # Join only after the queue has been drained. Joining before the
        # generator is consumed would hold back every event until all
        # branches had finished.
        for thread in threads:
            thread.join()

    # a falsy end_to_node_id is reported to the caller as None
    final_node_id = parallel.end_to_node_id or None
    return final_node_id, parallel_generator()
|
||||||
|
|
||||||
def _run_parallel_node(
|
def _run_parallel_node(
|
||||||
self,
|
self,
|
||||||
flask_app: Flask,
|
flask_app: Flask,
|
||||||
|
@ -99,9 +99,7 @@ class AppGenerateService:
|
|||||||
return max_active_requests
|
return max_active_requests
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_single_iteration(
|
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||||
cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True
|
|
||||||
):
|
|
||||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||||
return AdvancedChatAppGenerator().single_iteration_generate(
|
return AdvancedChatAppGenerator().single_iteration_generate(
|
||||||
|
@ -234,7 +234,7 @@ class WorkflowService:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if not node_run_result:
|
if not node_run_result:
|
||||||
raise ValueError('Node run failed with no run result')
|
raise ValueError("Node run failed with no run result")
|
||||||
|
|
||||||
run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False
|
run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False
|
||||||
error = node_run_result.error if not run_succeeded else None
|
error = node_run_result.error if not run_succeeded else None
|
||||||
@ -262,9 +262,15 @@ class WorkflowService:
|
|||||||
if run_succeeded and node_run_result:
|
if run_succeeded and node_run_result:
|
||||||
# create workflow node execution
|
# create workflow node execution
|
||||||
workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None
|
workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None
|
||||||
workflow_node_execution.process_data = json.dumps(node_run_result.process_data) if node_run_result.process_data else None
|
workflow_node_execution.process_data = (
|
||||||
workflow_node_execution.outputs = json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None
|
json.dumps(node_run_result.process_data) if node_run_result.process_data else None
|
||||||
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
|
)
|
||||||
|
workflow_node_execution.outputs = (
|
||||||
|
json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None
|
||||||
|
)
|
||||||
|
workflow_node_execution.execution_metadata = (
|
||||||
|
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
|
||||||
|
)
|
||||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||||
else:
|
else:
|
||||||
# create workflow node execution
|
# create workflow node execution
|
||||||
|
Loading…
x
Reference in New Issue
Block a user