fix(workflow): parallel execution after if-else where only one branch runs

This commit is contained in:
takatost 2024-08-28 15:53:39 +08:00
parent cd52633b0e
commit b0a81c654b
8 changed files with 153 additions and 88 deletions

View File

@ -19,14 +19,13 @@ class RunConditionHandler(ABC):
@abstractmethod
def check(self,
graph_runtime_state: GraphRuntimeState,
previous_route_node_state: RouteNodeState,
target_node_id: str) -> bool:
previous_route_node_state: RouteNodeState
) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:param target_node_id: target node id
:return: bool
"""
raise NotImplementedError

View File

@ -7,14 +7,12 @@ class BranchIdentifyRunConditionHandler(RunConditionHandler):
def check(self,
graph_runtime_state: GraphRuntimeState,
previous_route_node_state: RouteNodeState,
target_node_id: str) -> bool:
previous_route_node_state: RouteNodeState) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:param target_node_id: target node id
:return: bool
"""
if not self.condition.branch_identify:

View File

@ -7,14 +7,13 @@ from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self,
graph_runtime_state: GraphRuntimeState,
previous_route_node_state: RouteNodeState,
target_node_id: str) -> bool:
previous_route_node_state: RouteNodeState
) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:param target_node_id: target node id
:return: bool
"""
if not self.condition.conditions:

View File

@ -314,8 +314,22 @@ class Graph(BaseModel):
target_node_edges = edge_mapping.get(start_node_id, [])
if len(target_node_edges) > 1:
# fetch all node ids in current parallels
parallel_node_ids = [graph_edge.target_node_id
for graph_edge in target_node_edges if graph_edge.run_condition is None]
parallel_node_ids = []
condition_edge_mappings = {}
for graph_edge in target_node_edges:
if graph_edge.run_condition is None:
parallel_node_ids.append(graph_edge.target_node_id)
else:
condition_hash = graph_edge.run_condition.hash
if not condition_hash in condition_edge_mappings:
condition_edge_mappings[condition_hash] = []
condition_edge_mappings[condition_hash].append(graph_edge)
for _, graph_edges in condition_edge_mappings.items():
if len(graph_edges) > 1:
for graph_edge in graph_edges:
parallel_node_ids.append(graph_edge.target_node_id)
# any target node id in node_parallel_mapping
if parallel_node_ids:

View File

@ -1,3 +1,4 @@
import hashlib
from typing import Literal, Optional
from pydantic import BaseModel
@ -14,3 +15,7 @@ class RunCondition(BaseModel):
conditions: Optional[list[Condition]] = None
"""conditions to run the node, required when type is condition"""
@property
def hash(self) -> str:
    """SHA-256 hex digest of this condition's JSON serialization.

    Serves as a stable grouping key so edges that carry an identical run
    condition can be bucketed together.
    """
    serialized: bytes = self.model_dump_json().encode()
    return hashlib.sha256(serialized).hexdigest()

View File

@ -32,7 +32,7 @@ from core.workflow.graph_engine.entities.event import (
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph import Graph, GraphEdge
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
@ -262,8 +262,7 @@ class GraphEngine:
run_condition=edge.run_condition,
).check(
graph_runtime_state=self.graph_runtime_state,
previous_route_node_state=previous_route_node_state,
target_node_id=edge.target_node_id,
previous_route_node_state=previous_route_node_state
)
if not result:
@ -274,90 +273,137 @@ class GraphEngine:
if any(edge.run_condition for edge in edge_mappings):
# if nodes has run conditions, get node id which branch to take based on the run condition results
final_node_id = None
condition_edge_mappings = {}
for edge in edge_mappings:
if edge.run_condition:
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
graph=self.graph,
run_condition=edge.run_condition,
).check(
graph_runtime_state=self.graph_runtime_state,
previous_route_node_state=previous_route_node_state,
target_node_id=edge.target_node_id,
run_condition_hash = edge.run_condition.hash
if run_condition_hash not in condition_edge_mappings:
condition_edge_mappings[run_condition_hash] = []
condition_edge_mappings[run_condition_hash].append(edge)
for _, sub_edge_mappings in condition_edge_mappings.items():
if len(sub_edge_mappings) == 0:
continue
edge = sub_edge_mappings[0]
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
graph=self.graph,
run_condition=edge.run_condition,
).check(
graph_runtime_state=self.graph_runtime_state,
previous_route_node_state=previous_route_node_state,
)
if not result:
continue
if len(sub_edge_mappings) == 1:
final_node_id = edge.target_node_id
else:
final_node_id, parallel_generator = self._run_parallel_branches(
edge_mappings=sub_edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
)
if result:
final_node_id = edge.target_node_id
break
yield from parallel_generator
break
if not final_node_id:
break
next_node_id = final_node_id
else:
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id:
raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.')
next_node_id, parallel_generator = self._run_parallel_branches(
edge_mappings=edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
)
parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
yield from parallel_generator
# run parallel nodes, run in new thread and use queue to get results
q: queue.Queue = queue.Queue()
# Create a list to store the threads
threads = []
# new thread
for edge in edge_mappings:
thread = threading.Thread(target=self._run_parallel_node, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
'q': q,
'parallel_id': parallel_id,
'parallel_start_node_id': edge.target_node_id,
'parent_parallel_id': in_parallel_id,
'parent_parallel_start_node_id': parallel_start_node_id,
})
threads.append(thread)
thread.start()
succeeded_count = 0
while True:
try:
event = q.get(timeout=1)
if event is None:
break
yield event
if event.parallel_id == parallel_id:
if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1
if succeeded_count == len(threads):
q.put(None)
continue
elif isinstance(event, ParallelBranchRunFailedEvent):
raise GraphRunFailedError(event.error)
except queue.Empty:
continue
# Join all threads
for thread in threads:
thread.join()
# get final node id
final_node_id = parallel.end_to_node_id
if not final_node_id:
if not next_node_id:
break
next_node_id = final_node_id
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
break
def _run_parallel_branches(
        self,
        edge_mappings: list[GraphEdge],
        in_parallel_id: Optional[str] = None,
        parallel_start_node_id: Optional[str] = None,
) -> tuple[Optional[str], Generator[GraphEngineEvent, None, None]]:
    """
    Run the target nodes of the given edges as branches of one parallel.

    :param edge_mappings: edges whose target nodes form the parallel branches
    :param in_parallel_id: id of the enclosing (parent) parallel, if nested
    :param parallel_start_node_id: start node id of the parent parallel, if nested
    :return: tuple of (node id to continue from after the parallel, or None if
             the parallel has no end node; generator yielding branch events)
    :raises GraphRunFailedError: if the parallel cannot be resolved or a branch fails
    """
    # all edges point into the same parallel; resolve it via the first target node
    parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
    if not parallel_id:
        raise GraphRunFailedError(
            f'Node {edge_mappings[0].target_node_id} related parallel not found.'
        )

    parallel = self.graph.parallel_mapping.get(parallel_id)
    if not parallel:
        raise GraphRunFailedError(f'Parallel {parallel_id} not found.')

    # run parallel nodes: one worker thread per branch, results come back
    # through a shared queue
    q: queue.Queue = queue.Queue()

    threads = []
    for edge in edge_mappings:
        thread = threading.Thread(target=self._run_parallel_node, kwargs={
            'flask_app': current_app._get_current_object(),  # type: ignore[attr-defined]
            'q': q,
            'parallel_id': parallel_id,
            'parallel_start_node_id': edge.target_node_id,
            'parent_parallel_id': in_parallel_id,
            'parent_parallel_start_node_id': parallel_start_node_id,
        })
        threads.append(thread)
        thread.start()

    def parallel_generator() -> Generator[GraphEngineEvent, None, None]:
        # drain branch events until every branch has succeeded (None sentinel)
        # or one branch fails
        succeeded_count = 0
        while True:
            try:
                event = q.get(timeout=1)
                if event is None:
                    break

                yield event
                if event.parallel_id == parallel_id:
                    if isinstance(event, ParallelBranchRunSucceededEvent):
                        succeeded_count += 1
                        # all branches done: enqueue the sentinel so the loop exits
                        if succeeded_count == len(threads):
                            q.put(None)

                        continue
                    elif isinstance(event, ParallelBranchRunFailedEvent):
                        raise GraphRunFailedError(event.error)
            except queue.Empty:
                continue

    generator = parallel_generator()

    # Join all threads. The queue buffers the events they produced; the caller
    # receives them when it consumes the generator.
    for thread in threads:
        thread.join()

    # get final node id; None means the parallel has no single end node to
    # continue from (removed a dead leftover `next_node_id = final_node_id`
    # assignment here — the local was never read)
    final_node_id = parallel.end_to_node_id
    if not final_node_id:
        return None, generator

    return final_node_id, generator
def _run_parallel_node(
self,
flask_app: Flask,

View File

@ -99,9 +99,7 @@ class AppGenerateService:
return max_active_requests
@classmethod
def generate_single_iteration(
cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True
):
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator().single_iteration_generate(

View File

@ -234,7 +234,7 @@ class WorkflowService:
break
if not node_run_result:
raise ValueError('Node run failed with no run result')
raise ValueError("Node run failed with no run result")
run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False
error = node_run_result.error if not run_succeeded else None
@ -262,9 +262,15 @@ class WorkflowService:
if run_succeeded and node_run_result:
# create workflow node execution
workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None
workflow_node_execution.process_data = json.dumps(node_run_result.process_data) if node_run_result.process_data else None
workflow_node_execution.outputs = json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
workflow_node_execution.process_data = (
json.dumps(node_run_result.process_data) if node_run_result.process_data else None
)
workflow_node_execution.outputs = (
json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None
)
workflow_node_execution.execution_metadata = (
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
else:
# create workflow node execution