chore(workflow): Optimize the iteration when selecting a variable from a branch in the output variable causes iteration index err (#8440)

This commit is contained in:
takatost 2024-09-14 18:02:43 +08:00 committed by GitHub
parent d882348f39
commit 88c9834ef2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 80 additions and 127 deletions

View File

@ -689,24 +689,12 @@ class Graph(BaseModel):
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
parallel_start_node_id = None
for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
for _, branch_node_ids in parallel_start_node_ids.items():
if set(branch_node_ids) == set(routes_node_ids.keys()):
parallel_start_node_id = p_start_node_id
return True
if not parallel_start_node_id:
raise Exception("Parallel start node id not found")
for graph_edge in reverse_edge_mapping[start_node_id]:
if (
graph_edge.source_node_id not in all_routes_node_ids
or graph_edge.source_node_id != parallel_start_node_id
):
return False
return True
@classmethod
def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:
"""

View File

@ -20,11 +20,9 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.utils.condition.entities import Condition
from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)
@ -68,38 +66,6 @@ class IterationNode(BaseNode):
if not iteration_graph:
raise ValueError("iteration graph not found")
leaf_node_ids = iteration_graph.get_leaf_node_ids()
iteration_leaf_node_ids = []
for leaf_node_id in leaf_node_ids:
node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
if not node_config:
continue
leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
if not leaf_node_iteration_id:
continue
if leaf_node_iteration_id != self.node_id:
continue
iteration_leaf_node_ids.append(leaf_node_id)
# add condition of end nodes to root node
iteration_graph.add_extra_edge(
source_node_id=leaf_node_id,
target_node_id=root_node_id,
run_condition=RunCondition(
type="condition",
conditions=[
Condition(
variable_selector=[self.node_id, "index"],
comparison_operator="<",
value=str(len(iterator_list_value)),
)
],
),
)
variable_pool = self.graph_runtime_state.variable_pool
# append iteration variable (item, index) to variable pool
@ -149,6 +115,7 @@ class IterationNode(BaseNode):
outputs: list[Any] = []
try:
for _ in range(len(iterator_list_value)):
# run workflow
rst = graph_engine.run()
for event in rst:
@ -176,9 +143,33 @@ class IterationNode(BaseNode):
event.route_node_state.node_run_result.metadata = metadata
yield event
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
return
else:
event = cast(InNodeEvent, event)
yield event
# handle iteration run result
if event.route_node_state.node_id in iteration_leaf_node_ids:
# append to iteration output variable list
current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
outputs.append(current_iteration_output)
@ -208,32 +199,6 @@ class IterationNode(BaseNode):
if current_iteration_output
else None,
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
break
else:
event = cast(InNodeEvent, event)
yield event
yield IterationRunSucceededEvent(
iteration_id=self.id,