Feat/loop node (#17273)

commit 8c77f2dc03
parent 11e95d2a61
@@ -30,6 +30,7 @@ class NodeRunMetadataKey(StrEnum):
     ITERATION_DURATION_MAP = "iteration_duration_map"  # single iteration duration if iteration node runs
     LOOP_DURATION_MAP = "loop_duration_map"  # single loop duration if loop node runs
     ERROR_STRATEGY = "error_strategy"  # node in continue on error mode return the field
+    LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output
 
 
 class NodeRunResult(BaseModel):
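Note: a minimal, hypothetical sketch of how a loop node run might populate these metadata keys. The enum members and their string values are copied from the hunk above; the payload shape (iteration index keyed as a string) follows the loop node changes later in this diff, and the concrete numbers are invented.

```python
from enum import StrEnum


class NodeRunMetadataKey(StrEnum):
    # Subset of the enum touched by this diff; values copied from the hunk above.
    LOOP_DURATION_MAP = "loop_duration_map"  # single loop duration if loop node runs
    LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output


# Hypothetical metadata payload a loop node run might attach:
# per-iteration durations in seconds and per-iteration loop-variable snapshots.
metadata = {
    NodeRunMetadataKey.LOOP_DURATION_MAP: {"0": 0.42, "1": 0.37},
    NodeRunMetadataKey.LOOP_VARIABLE_MAP: {"0": {"counter": 1}, "1": {"counter": 2}},
}

for key, value in metadata.items():
    print(f"{key.value}: {value}")
```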
@@ -17,6 +17,7 @@ class NodeType(StrEnum):
     LEGACY_VARIABLE_AGGREGATOR = "variable-assigner"  # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
     LOOP = "loop"
     LOOP_START = "loop-start"
+    LOOP_END = "loop-end"
     ITERATION = "iteration"
     ITERATION_START = "iteration-start"  # Fake start node for iteration.
     PARAMETER_EXTRACTOR = "parameter-extractor"
@@ -1,5 +1,6 @@
 from .entities import LoopNodeData
+from .loop_end_node import LoopEndNode
 from .loop_node import LoopNode
 from .loop_start_node import LoopStartNode
 
-__all__ = ["LoopNode", "LoopNodeData", "LoopStartNode"]
+__all__ = ["LoopEndNode", "LoopNode", "LoopNodeData", "LoopStartNode"]
@@ -1,11 +1,23 @@
+from collections.abc import Mapping
 from typing import Any, Literal, Optional
 
-from pydantic import Field
+from pydantic import BaseModel, Field
 
 from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
 from core.workflow.utils.condition.entities import Condition
 
 
+class LoopVariableData(BaseModel):
+    """
+    Loop Variable Data.
+    """
+
+    label: str
+    var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
+    value_type: Literal["variable", "constant"]
+    value: Optional[Any | list[str]] = None
+
+
 class LoopNodeData(BaseLoopNodeData):
     """
     Loop Node Data.
@@ -14,6 +26,8 @@ class LoopNodeData(BaseLoopNodeData):
     loop_count: int  # Maximum number of loops
     break_conditions: list[Condition]  # Conditions to break the loop
     logical_operator: Literal["and", "or"]
+    loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list)
+    outputs: Optional[Mapping[str, Any]] = None
 
 
 class LoopStartNodeData(BaseNodeData):
@@ -24,6 +38,14 @@ class LoopStartNodeData(BaseNodeData):
     pass
 
 
+class LoopEndNodeData(BaseNodeData):
+    """
+    Loop End Node Data.
+    """
+
+    pass
+
+
 class LoopState(BaseLoopState):
     """
     Loop State.
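Note: a minimal sketch of how the new fields might be populated, using stand-in Pydantic models (the real LoopNodeData extends BaseLoopNodeData and uses Condition objects for break_conditions); field names and literal values follow the hunks above, the surrounding scaffolding is assumed.

```python
from typing import Any, Literal, Optional

from pydantic import BaseModel, Field


class LoopVariableData(BaseModel):
    """Loop variable definition, mirroring the fields added in entities.py."""

    label: str
    var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
    value_type: Literal["variable", "constant"]
    value: Optional[Any | list[str]] = None


class LoopNodeDataSketch(BaseModel):
    """Stand-in for LoopNodeData (the real class extends BaseLoopNodeData)."""

    loop_count: int
    logical_operator: Literal["and", "or"] = "and"
    loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list)
    outputs: Optional[dict[str, Any]] = None


data = LoopNodeDataSketch(
    loop_count=3,
    loop_variables=[
        LoopVariableData(label="counter", var_type="number", value_type="constant", value=0),
        LoopVariableData(label="items", var_type="array[string]", value_type="variable", value=["start", "items"]),
    ],
)
print(data.model_dump())
```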
api/core/workflow/nodes/loop/loop_end_node.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.loop.entities import LoopEndNodeData
+from models.workflow import WorkflowNodeExecutionStatus
+
+
+class LoopEndNode(BaseNode[LoopEndNodeData]):
+    """
+    Loop End Node.
+    """
+
+    _node_data_cls = LoopEndNodeData
+    _node_type = NodeType.LOOP_END
+
+    def _run(self) -> NodeRunResult:
+        """
+        Run the node.
+        """
+        return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
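Note: the node itself only reports success; the break semantics live in the loop node further down, where a succeeded loop-end event inside the loop sub-graph ends the current iteration. A standalone sketch of that check, using simplified stand-ins for the real event and enum classes.

```python
from dataclasses import dataclass
from enum import StrEnum


class NodeType(StrEnum):
    LOOP_START = "loop-start"
    LOOP_END = "loop-end"
    CODE = "code"


@dataclass
class NodeRunSucceededEvent:
    """Hypothetical stand-in for the real event class."""

    node_type: NodeType


def should_break(event: object) -> bool:
    # Mirrors the check added in _run_single_loop: a succeeded loop-end node
    # inside the loop sub-graph means "break out of the loop".
    return isinstance(event, NodeRunSucceededEvent) and event.node_type == NodeType.LOOP_END


events = [NodeRunSucceededEvent(NodeType.CODE), NodeRunSucceededEvent(NodeType.LOOP_END)]
print([should_break(e) for e in events])  # [False, True]
```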
@@ -1,10 +1,20 @@
+import json
 import logging
 from collections.abc import Generator, Mapping, Sequence
 from datetime import UTC, datetime
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, Literal, cast
 
 from configs import dify_config
-from core.variables import IntegerSegment
+from core.variables import (
+    ArrayNumberSegment,
+    ArrayObjectSegment,
+    ArrayStringSegment,
+    IntegerSegment,
+    ObjectSegment,
+    Segment,
+    SegmentType,
+    StringSegment,
+)
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
 from core.workflow.graph_engine.entities.event import (
     BaseGraphEvent,
@@ -29,6 +39,10 @@ from core.workflow.nodes.loop.entities import LoopNodeData
 from core.workflow.utils.condition.processor import ConditionProcessor
 from models.workflow import WorkflowNodeExecutionStatus
 
+if TYPE_CHECKING:
+    from core.workflow.entities.variable_pool import VariablePool
+    from core.workflow.graph_engine.graph_engine import GraphEngine
+
 logger = logging.getLogger(__name__)
 
 
@@ -61,6 +75,28 @@ class LoopNode(BaseNode[LoopNodeData]):
         variable_pool = self.graph_runtime_state.variable_pool
         variable_pool.add([self.node_id, "index"], 0)
 
+        # Initialize loop variables
+        loop_variable_selectors = {}
+        if self.node_data.loop_variables:
+            for loop_variable in self.node_data.loop_variables:
+                value_processor = {
+                    "constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
+                    "variable": lambda var=loop_variable: variable_pool.get(var.value),
+                }
+
+                if loop_variable.value_type not in value_processor:
+                    raise ValueError(
+                        f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
+                    )
+
+                processed_segment = value_processor[loop_variable.value_type]()
+                if not processed_segment:
+                    raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
+                variable_selector = [self.node_id, loop_variable.label]
+                variable_pool.add(variable_selector, processed_segment.value)
+                loop_variable_selectors[loop_variable.label] = variable_selector
+                inputs[loop_variable.label] = processed_segment.value
+
         from core.workflow.graph_engine.graph_engine import GraphEngine
 
         graph_engine = GraphEngine(
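Note: a simplified sketch of the constant-vs-variable dispatch added above, with a plain dict standing in for the variable pool and raw values in place of typed Segments; the names here are illustrative, not the project's API.

```python
from typing import Any


def init_loop_variables(loop_variables: list[dict[str, Any]], variable_pool: dict[str, Any], node_id: str) -> dict:
    """Resolve each loop variable to an initial value and register it under the loop node's id."""
    inputs: dict[str, Any] = {}
    selectors: dict[str, tuple[str, str]] = {}
    for var in loop_variables:
        value_processor = {
            # constants are taken as-is here; the real node wraps them in typed Segments
            "constant": lambda var=var: var["value"],
            # variables are looked up by selector, simplified here to a dotted dict key
            "variable": lambda var=var: variable_pool.get(".".join(var["value"])),
        }
        if var["value_type"] not in value_processor:
            raise ValueError(f"Invalid value type '{var['value_type']}' for loop variable {var['label']}")
        resolved = value_processor[var["value_type"]]()
        selector = (node_id, var["label"])
        variable_pool[".".join(selector)] = resolved
        selectors[var["label"]] = selector
        inputs[var["label"]] = resolved
    return {"inputs": inputs, "selectors": selectors}


pool = {"start.items": ["a", "b"]}
print(init_loop_variables(
    [
        {"label": "counter", "value_type": "constant", "value": 0},
        {"label": "items", "value_type": "variable", "value": ["start", "items"]},
    ],
    pool,
    node_id="loop_1",
))
```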
@@ -95,135 +131,51 @@ class LoopNode(BaseNode[LoopNodeData]):
             predecessor_node_id=self.previous_node_id,
         )
 
-        yield LoopRunNextEvent(
-            loop_id=self.id,
-            loop_node_id=self.node_id,
-            loop_node_type=self.node_type,
-            loop_node_data=self.node_data,
-            index=0,
-            pre_loop_output=None,
-        )
+        # yield LoopRunNextEvent(
+        #     loop_id=self.id,
+        #     loop_node_id=self.node_id,
+        #     loop_node_type=self.node_type,
+        #     loop_node_data=self.node_data,
+        #     index=0,
+        #     pre_loop_output=None,
+        # )
+        loop_duration_map = {}
+        single_loop_variable_map = {}  # single loop variable output
         try:
             check_break_result = False
             for i in range(loop_count):
-                # Run workflow
-                rst = graph_engine.run()
-                current_index_variable = variable_pool.get([self.node_id, "index"])
-                if not isinstance(current_index_variable, IntegerSegment):
-                    raise ValueError(f"loop {self.node_id} current index not found")
-                current_index = current_index_variable.value
+                loop_start_time = datetime.now(UTC).replace(tzinfo=None)
+                # run single loop
+                loop_result = yield from self._run_single_loop(
+                    graph_engine=graph_engine,
+                    loop_graph=loop_graph,
+                    variable_pool=variable_pool,
+                    loop_variable_selectors=loop_variable_selectors,
+                    break_conditions=break_conditions,
+                    logical_operator=logical_operator,
+                    condition_processor=condition_processor,
+                    current_index=i,
+                    start_at=start_at,
+                    inputs=inputs,
+                )
+                loop_end_time = datetime.now(UTC).replace(tzinfo=None)
 
-                check_break_result = False
-
-                for event in rst:
-                    if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
-                        event.in_loop_id = self.node_id
-
-                    if (
-                        isinstance(event, BaseNodeEvent)
-                        and event.node_type == NodeType.LOOP_START
-                        and not isinstance(event, NodeRunStreamChunkEvent)
-                    ):
-                        continue
-
-                    if isinstance(event, NodeRunSucceededEvent):
-                        yield self._handle_event_metadata(event=event, iter_run_index=current_index)
-
-                        # Check if all variables in break conditions exist
-                        exists_variable = False
-                        for condition in break_conditions:
-                            if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
-                                exists_variable = False
-                                break
-                            else:
-                                exists_variable = True
-                        if exists_variable:
-                            input_conditions, group_result, check_break_result = condition_processor.process_conditions(
-                                variable_pool=self.graph_runtime_state.variable_pool,
-                                conditions=break_conditions,
-                                operator=logical_operator,
-                            )
-                            if check_break_result:
-                                break
-
-                    elif isinstance(event, BaseGraphEvent):
-                        if isinstance(event, GraphRunFailedEvent):
-                            # Loop run failed
-                            yield LoopRunFailedEvent(
-                                loop_id=self.id,
-                                loop_node_id=self.node_id,
-                                loop_node_type=self.node_type,
-                                loop_node_data=self.node_data,
-                                start_at=start_at,
-                                inputs=inputs,
-                                steps=i,
-                                metadata={
-                                    NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
-                                    "completed_reason": "error",
-                                },
-                                error=event.error,
-                            )
-                            yield RunCompletedEvent(
-                                run_result=NodeRunResult(
-                                    status=WorkflowNodeExecutionStatus.FAILED,
-                                    error=event.error,
-                                    metadata={
-                                        NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
-                                    },
-                                )
-                            )
-                            return
-                    elif isinstance(event, NodeRunFailedEvent):
-                        # Loop run failed
-                        yield event
-                        yield LoopRunFailedEvent(
-                            loop_id=self.id,
-                            loop_node_id=self.node_id,
-                            loop_node_type=self.node_type,
-                            loop_node_data=self.node_data,
-                            start_at=start_at,
-                            inputs=inputs,
-                            steps=i,
-                            metadata={
-                                NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
-                                "completed_reason": "error",
-                            },
-                            error=event.error,
-                        )
-                        yield RunCompletedEvent(
-                            run_result=NodeRunResult(
-                                status=WorkflowNodeExecutionStatus.FAILED,
-                                error=event.error,
-                                metadata={
-                                    NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
-                                },
-                            )
-                        )
-                        return
-                    else:
-                        yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
-
-                # Remove all nodes outputs from variable pool
-                for node_id in loop_graph.node_ids:
-                    variable_pool.remove([node_id])
+                single_loop_variable = {}
+                for key, selector in loop_variable_selectors.items():
+                    item = variable_pool.get(selector)
+                    if item:
+                        single_loop_variable[key] = item.value
+                    else:
+                        single_loop_variable[key] = None
+
+                loop_duration_map[str(i)] = (loop_end_time - loop_start_time).total_seconds()
+                single_loop_variable_map[str(i)] = single_loop_variable
+
+                check_break_result = loop_result.get("check_break_result", False)
 
                 if check_break_result:
                     break
 
-                # Move to next loop
-                next_index = current_index + 1
-                variable_pool.add([self.node_id, "index"], next_index)
-
-                yield LoopRunNextEvent(
-                    loop_id=self.id,
-                    loop_node_id=self.node_id,
-                    loop_node_type=self.node_type,
-                    loop_node_data=self.node_data,
-                    index=next_index,
-                    pre_loop_output=None,
-                )
-
             # Loop completed successfully
             yield LoopRunSucceededEvent(
                 loop_id=self.id,
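Note: a self-contained sketch of the bookkeeping this hunk introduces: each iteration is timed and the loop variables are snapshotted under the stringified iteration index, which is the shape later emitted as LOOP_DURATION_MAP and LOOP_VARIABLE_MAP. The timing call and loop body below are placeholders, not the real engine.

```python
import time


def run_loop(loop_count: int) -> dict:
    loop_duration_map: dict[str, float] = {}
    single_loop_variable_map: dict[str, dict] = {}
    loop_variables = {"counter": 0}

    for i in range(loop_count):
        start = time.monotonic()
        # placeholder for running the loop sub-graph once
        loop_variables["counter"] += 1
        time.sleep(0.01)
        duration = time.monotonic() - start

        loop_duration_map[str(i)] = duration
        single_loop_variable_map[str(i)] = dict(loop_variables)  # snapshot, not a live reference

    return {"loop_duration_map": loop_duration_map, "loop_variable_map": single_loop_variable_map}


print(run_loop(3))
```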
@@ -232,17 +184,26 @@ class LoopNode(BaseNode[LoopNodeData]):
                 loop_node_data=self.node_data,
                 start_at=start_at,
                 inputs=inputs,
+                outputs=self.node_data.outputs,
                 steps=loop_count,
                 metadata={
                     NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
                     "completed_reason": "loop_break" if check_break_result else "loop_completed",
+                    NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
+                    NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
                 },
             )
 
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                    metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
+                    metadata={
+                        NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
+                        NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
+                        NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
+                    },
+                    outputs=self.node_data.outputs,
+                    inputs=inputs,
                 )
             )
 
@@ -260,6 +221,8 @@ class LoopNode(BaseNode[LoopNodeData]):
                 metadata={
                     "total_tokens": graph_engine.graph_runtime_state.total_tokens,
                     "completed_reason": "error",
+                    NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
+                    NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
                 },
                 error=str(e),
             )
@@ -268,7 +231,11 @@ class LoopNode(BaseNode[LoopNodeData]):
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     error=str(e),
-                    metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
+                    metadata={
+                        NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
+                        NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
+                        NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
+                    },
                 )
             )
 
@@ -276,6 +243,159 @@ class LoopNode(BaseNode[LoopNodeData]):
             # Clean up
             variable_pool.remove([self.node_id, "index"])
 
+    def _run_single_loop(
+        self,
+        *,
+        graph_engine: "GraphEngine",
+        loop_graph: Graph,
+        variable_pool: "VariablePool",
+        loop_variable_selectors: dict,
+        break_conditions: list,
+        logical_operator: Literal["and", "or"],
+        condition_processor: ConditionProcessor,
+        current_index: int,
+        start_at: datetime,
+        inputs: dict,
+    ) -> Generator[NodeEvent | InNodeEvent, None, dict]:
+        """Run a single loop iteration.
+        Returns:
+            dict: {'check_break_result': bool}
+        """
+        # Run workflow
+        rst = graph_engine.run()
+        current_index_variable = variable_pool.get([self.node_id, "index"])
+        if not isinstance(current_index_variable, IntegerSegment):
+            raise ValueError(f"loop {self.node_id} current index not found")
+        current_index = current_index_variable.value
+
+        check_break_result = False
+
+        for event in rst:
+            if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
+                event.in_loop_id = self.node_id
+
+            if (
+                isinstance(event, BaseNodeEvent)
+                and event.node_type == NodeType.LOOP_START
+                and not isinstance(event, NodeRunStreamChunkEvent)
+            ):
+                continue
+
+            if (
+                isinstance(event, NodeRunSucceededEvent)
+                and event.node_type == NodeType.LOOP_END
+                and not isinstance(event, NodeRunStreamChunkEvent)
+            ):
+                check_break_result = True
+                yield self._handle_event_metadata(event=event, iter_run_index=current_index)
+                break
+
+            if isinstance(event, NodeRunSucceededEvent):
+                yield self._handle_event_metadata(event=event, iter_run_index=current_index)
+
+                # Check if all variables in break conditions exist
+                exists_variable = False
+                for condition in break_conditions:
+                    if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
+                        exists_variable = False
+                        break
+                    else:
+                        exists_variable = True
+                if exists_variable:
+                    input_conditions, group_result, check_break_result = condition_processor.process_conditions(
+                        variable_pool=self.graph_runtime_state.variable_pool,
+                        conditions=break_conditions,
+                        operator=logical_operator,
+                    )
+                    if check_break_result:
+                        break
+
+            elif isinstance(event, BaseGraphEvent):
+                if isinstance(event, GraphRunFailedEvent):
+                    # Loop run failed
+                    yield LoopRunFailedEvent(
+                        loop_id=self.id,
+                        loop_node_id=self.node_id,
+                        loop_node_type=self.node_type,
+                        loop_node_data=self.node_data,
+                        start_at=start_at,
+                        inputs=inputs,
+                        steps=current_index,
+                        metadata={
+                            NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
+                            "completed_reason": "error",
+                        },
+                        error=event.error,
+                    )
+                    yield RunCompletedEvent(
+                        run_result=NodeRunResult(
+                            status=WorkflowNodeExecutionStatus.FAILED,
+                            error=event.error,
+                            metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
+                        )
+                    )
+                    return {"check_break_result": True}
+            elif isinstance(event, NodeRunFailedEvent):
+                # Loop run failed
+                yield event
+                yield LoopRunFailedEvent(
+                    loop_id=self.id,
+                    loop_node_id=self.node_id,
+                    loop_node_type=self.node_type,
+                    loop_node_data=self.node_data,
+                    start_at=start_at,
+                    inputs=inputs,
+                    steps=current_index,
+                    metadata={
+                        NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
+                        "completed_reason": "error",
+                    },
+                    error=event.error,
+                )
+                yield RunCompletedEvent(
+                    run_result=NodeRunResult(
+                        status=WorkflowNodeExecutionStatus.FAILED,
+                        error=event.error,
+                        metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
+                    )
+                )
+                return {"check_break_result": True}
+            else:
+                yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
+
+        # Remove all nodes outputs from variable pool
+        for node_id in loop_graph.node_ids:
+            variable_pool.remove([node_id])
+
+        _outputs = {}
+        for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
+            _loop_variable_segment = variable_pool.get(loop_variable_selector)
+            if _loop_variable_segment:
+                _outputs[loop_variable_key] = _loop_variable_segment.value
+            else:
+                _outputs[loop_variable_key] = None
+
+        _outputs["loop_round"] = current_index + 1
+        self.node_data.outputs = _outputs
+
+        if check_break_result:
+            return {"check_break_result": True}
+
+        # Move to next loop
+        next_index = current_index + 1
+        variable_pool.add([self.node_id, "index"], next_index)
+
+        yield LoopRunNextEvent(
+            loop_id=self.id,
+            loop_node_id=self.node_id,
+            loop_node_type=self.node_type,
+            loop_node_data=self.node_data,
+            index=next_index,
+            pre_loop_output=self.node_data.outputs,
+        )
+
+        return {"check_break_result": False}
+
     def _handle_event_metadata(
         self,
         *,
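Note: the refactor leans on a Python detail worth spelling out. `yield from` both re-yields a sub-generator's events and captures its `return` value, which is how `_run` receives `{"check_break_result": ...}` from `_run_single_loop` while still streaming events. A minimal standalone demonstration (the event strings are placeholders).

```python
from collections.abc import Generator


def single_loop(i: int) -> Generator[str, None, dict]:
    """Yields events for one iteration, then returns a result dict."""
    yield f"iteration {i}: node succeeded"
    # break after the second iteration, just to exercise the return path
    return {"check_break_result": i >= 1}


def run() -> Generator[str, None, None]:
    for i in range(5):
        # `yield from` re-yields every event and captures the generator's return value
        result = yield from single_loop(i)
        if result["check_break_result"]:
            yield "loop break"
            break


print(list(run()))  # ['iteration 0: node succeeded', 'iteration 1: node succeeded', 'loop break']
```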
@@ -360,3 +480,25 @@ class LoopNode(BaseNode[LoopNodeData]):
         }
 
         return variable_mapping
+
+    @staticmethod
+    def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
+        """Get the appropriate segment type for a constant value."""
+        segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
+            "string": (StringSegment, SegmentType.STRING),
+            "number": (IntegerSegment, SegmentType.NUMBER),
+            "object": (ObjectSegment, SegmentType.OBJECT),
+            "array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
+            "array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
+            "array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
+        }
+        if var_type in ["array[string]", "array[number]", "array[object]"]:
+            if value:
+                value = json.loads(value)
+            else:
+                value = []
+        segment_info = segment_mapping.get(var_type)
+        if not segment_info:
+            raise ValueError(f"Invalid variable type: {var_type}")
+        segment_class, value_type = segment_info
+        return segment_class(value=value, value_type=value_type)
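Note: a simplified sketch of the constant coercion above, with plain Python types standing in for the Segment classes. The behavior mirrored here is that array-typed constants arrive as JSON strings and are parsed (or defaulted to an empty list) before use, and unknown types raise.

```python
import json
from typing import Any

# Plain Python stand-ins for the (SegmentClass, SegmentType) pairs used above.
TYPE_MAPPING: dict[str, tuple[type, ...]] = {
    "string": (str,),
    "number": (int, float),
    "object": (dict,),
    "array[string]": (list,),
    "array[number]": (list,),
    "array[object]": (list,),
}


def constant_value(var_type: str, value: Any) -> Any:
    """Coerce a constant loop-variable value roughly as _get_segment_for_constant does, minus Segment wrapping."""
    if var_type in ("array[string]", "array[number]", "array[object]"):
        # array constants are stored as JSON strings; empty/None falls back to an empty list
        value = json.loads(value) if value else []
    expected = TYPE_MAPPING.get(var_type)
    if expected is None:
        raise ValueError(f"Invalid variable type: {var_type}")
    if not isinstance(value, expected):
        raise ValueError(f"Expected {var_type}, got {type(value).__name__}")
    return value


print(constant_value("array[number]", "[1, 2, 3]"))  # [1, 2, 3]
print(constant_value("number", 7))                   # 7
```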
@@ -13,7 +13,7 @@ from core.workflow.nodes.iteration import IterationNode, IterationStartNode
 from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
 from core.workflow.nodes.list_operator import ListOperatorNode
 from core.workflow.nodes.llm import LLMNode
-from core.workflow.nodes.loop import LoopNode, LoopStartNode
+from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
 from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
 from core.workflow.nodes.question_classifier import QuestionClassifierNode
 from core.workflow.nodes.start import StartNode
@@ -94,6 +94,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
         LATEST_VERSION: LoopStartNode,
         "1": LoopStartNode,
     },
+    NodeType.LOOP_END: {
+        LATEST_VERSION: LoopEndNode,
+        "1": LoopEndNode,
+    },
     NodeType.PARAMETER_EXTRACTOR: {
         LATEST_VERSION: ParameterExtractorNode,
         "1": ParameterExtractorNode,
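Note: a sketch of the versioned-registry pattern this mapping follows — classes keyed first by node type, then by version string, with a latest alias. The names and resolver below are stand-ins rather than the real mapping or lookup code.

```python
from collections.abc import Mapping

LATEST_VERSION = "latest"


class LoopEndNodeSketch:
    """Stand-in for the real LoopEndNode class."""


NODE_TYPE_CLASSES: Mapping[str, Mapping[str, type]] = {
    "loop-end": {
        LATEST_VERSION: LoopEndNodeSketch,
        "1": LoopEndNodeSketch,
    },
}


def resolve_node_class(node_type: str, version: str | None = None) -> type:
    versions = NODE_TYPE_CLASSES[node_type]
    return versions.get(version or LATEST_VERSION, versions[LATEST_VERSION])


print(resolve_node_class("loop-end").__name__)       # LoopEndNodeSketch
print(resolve_node_class("loop-end", "1").__name__)  # LoopEndNodeSketch
```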
@@ -2,6 +2,7 @@ import json
 from collections.abc import Sequence
 from typing import Any, cast
 
+from core.app.entities.app_invoke_entities import InvokeFrom
 from core.variables import SegmentType, Variable
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
 from core.workflow.entities.node_entities import NodeRunResult
@@ -123,13 +124,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
             if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
                 conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
                 if not conversation_id:
-                    raise ConversationIDNotFoundError
+                    if self.invoke_from != InvokeFrom.DEBUGGER:
+                        raise ConversationIDNotFoundError
                 else:
                     conversation_id = conversation_id.value
                     common_helpers.update_conversation_variable(
                         conversation_id=cast(str, conversation_id),
                         variable=variable,
                     )
 
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
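Note: a sketch of the behavior change in the conversation-variable branch — a missing conversation_id now raises only outside debugger invocations, and is simply skipped in debugger runs. The enum values and error class below are simplified stand-ins.

```python
from enum import StrEnum


class InvokeFrom(StrEnum):
    DEBUGGER = "debugger"
    SERVICE_API = "service-api"


class ConversationIDNotFoundError(Exception):
    pass


def resolve_conversation_id(conversation_id: str | None, invoke_from: InvokeFrom) -> str | None:
    """Raise only when the variable is missing outside debugger runs, mirroring the new guard."""
    if not conversation_id:
        if invoke_from != InvokeFrom.DEBUGGER:
            raise ConversationIDNotFoundError
        return None  # debugger runs tolerate a missing conversation id
    return conversation_id


print(resolve_conversation_id(None, InvokeFrom.DEBUGGER))         # None
print(resolve_conversation_id("abc123", InvokeFrom.SERVICE_API))  # abc123
```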