Feat/loop node (#17273)

Dongyu Li 2025-04-02 13:53:26 +08:00 committed by GitHub
parent 11e95d2a61
commit 8c77f2dc03
8 changed files with 324 additions and 131 deletions

View File

@@ -30,6 +30,7 @@ class NodeRunMetadataKey(StrEnum):
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy" # returned by a node when it runs in continue-on-error mode
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
class NodeRunResult(BaseModel):

View File

@@ -17,6 +17,7 @@ class NodeType(StrEnum):
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
LOOP_START = "loop-start"
LOOP_END = "loop-end"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"

View File

@@ -1,5 +1,6 @@
from .entities import LoopNodeData
from .loop_end_node import LoopEndNode
from .loop_node import LoopNode
from .loop_start_node import LoopStartNode
__all__ = ["LoopNode", "LoopNodeData", "LoopStartNode"]
__all__ = ["LoopEndNode", "LoopNode", "LoopNodeData", "LoopStartNode"]

View File

@@ -1,11 +1,23 @@
from collections.abc import Mapping
from typing import Any, Literal, Optional
from pydantic import Field
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition
class LoopVariableData(BaseModel):
"""
Loop Variable Data.
"""
label: str
var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
value_type: Literal["variable", "constant"]
value: Optional[Any | list[str]] = None
class LoopNodeData(BaseLoopNodeData):
"""
Loop Node Data.
@@ -14,6 +26,8 @@ class LoopNodeData(BaseLoopNodeData):
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list)
outputs: Optional[Mapping[str, Any]] = None
class LoopStartNodeData(BaseNodeData):
@@ -24,6 +38,14 @@ class LoopStartNodeData(BaseNodeData):
pass
class LoopEndNodeData(BaseNodeData):
"""
Loop End Node Data.
"""
pass
class LoopState(BaseLoopState):
"""
Loop State.

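Not part of the diff — a minimal sketch of how the new loop_variables field could be populated. The required BaseNodeData/BaseLoopNodeData fields (title, start_node_id) are assumptions, and the selector used for the "variable" case is purely illustrative; array constants are passed as JSON strings, which _get_segment_for_constant (added below in loop_node.py) decodes.

# Illustrative only; field names outside this hunk are assumptions.
counter = LoopVariableData(label="counter", var_type="number", value_type="constant", value=0)
names = LoopVariableData(label="names", var_type="array[string]", value_type="constant", value='["a", "b"]')
seed = LoopVariableData(label="seed", var_type="string", value_type="variable", value=["start", "query"])  # variable selector (hypothetical)
node_data = LoopNodeData(
    title="loop",                # from BaseNodeData (assumed)
    start_node_id="loop-start",  # from BaseLoopNodeData (assumed)
    loop_count=3,
    break_conditions=[],
    logical_operator="and",
    loop_variables=[counter, names, seed],
)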
View File

@@ -0,0 +1,20 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData
from models.workflow import WorkflowNodeExecutionStatus
class LoopEndNode(BaseNode[LoopEndNodeData]):
"""
Loop End Node.
"""
_node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END
def _run(self) -> NodeRunResult:
"""
Run the node.
"""
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
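LoopEndNode deliberately does nothing beyond reporting success; the break semantics live in the parent loop. As the loop_node.py hunk further down shows, LoopNode watches for this node's success event, roughly:

# Condensed from _run_single_loop below, only to illustrate this node's role.
if isinstance(event, NodeRunSucceededEvent) and event.node_type == NodeType.LOOP_END:
    check_break_result = True  # finish the current round, then stop looping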

View File

@@ -1,10 +1,20 @@
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from configs import dify_config
from core.variables import IntegerSegment
from core.variables import (
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
IntegerSegment,
ObjectSegment,
Segment,
SegmentType,
StringSegment,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
@@ -29,6 +39,10 @@ from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
from models.workflow import WorkflowNodeExecutionStatus
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@@ -61,6 +75,28 @@ class LoopNode(BaseNode[LoopNodeData]):
variable_pool = self.graph_runtime_state.variable_pool
variable_pool.add([self.node_id, "index"], 0)
# Initialize loop variables
loop_variable_selectors = {}
if self.node_data.loop_variables:
for loop_variable in self.node_data.loop_variables:
value_processor = {
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var=loop_variable: variable_pool.get(var.value),
}
if loop_variable.value_type not in value_processor:
raise ValueError(
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
)
processed_segment = value_processor[loop_variable.value_type]()
if not processed_segment:
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
variable_selector = [self.node_id, loop_variable.label]
variable_pool.add(variable_selector, processed_segment.value)
loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_engine = GraphEngine(
@@ -95,18 +131,136 @@ class LoopNode(BaseNode[LoopNodeData]):
predecessor_node_id=self.previous_node_id,
)
yield LoopRunNextEvent(
# yield LoopRunNextEvent(
# loop_id=self.id,
# loop_node_id=self.node_id,
# loop_node_type=self.node_type,
# loop_node_data=self.node_data,
# index=0,
# pre_loop_output=None,
# )
loop_duration_map = {}
single_loop_variable_map = {} # single loop variable output
try:
check_break_result = False
for i in range(loop_count):
loop_start_time = datetime.now(UTC).replace(tzinfo=None)
# run single loop
loop_result = yield from self._run_single_loop(
graph_engine=graph_engine,
loop_graph=loop_graph,
variable_pool=variable_pool,
loop_variable_selectors=loop_variable_selectors,
break_conditions=break_conditions,
logical_operator=logical_operator,
condition_processor=condition_processor,
current_index=i,
start_at=start_at,
inputs=inputs,
)
loop_end_time = datetime.now(UTC).replace(tzinfo=None)
single_loop_variable = {}
for key, selector in loop_variable_selectors.items():
item = variable_pool.get(selector)
if item:
single_loop_variable[key] = item.value
else:
single_loop_variable[key] = None
loop_duration_map[str(i)] = (loop_end_time - loop_start_time).total_seconds()
single_loop_variable_map[str(i)] = single_loop_variable
check_break_result = loop_result.get("check_break_result", False)
if check_break_result:
break
# Loop completed successfully
yield LoopRunSucceededEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
index=0,
pre_loop_output=None,
start_at=start_at,
inputs=inputs,
outputs=self.node_data.outputs,
steps=loop_count,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "loop_break" if check_break_result else "loop_completed",
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)
try:
check_break_result = False
for i in range(loop_count):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self.node_data.outputs,
inputs=inputs,
)
)
except Exception as e:
# Loop failed
logger.exception("Loop run failed")
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
steps=loop_count,
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)
)
finally:
# Clean up
variable_pool.remove([self.node_id, "index"])
def _run_single_loop(
self,
*,
graph_engine: "GraphEngine",
loop_graph: Graph,
variable_pool: "VariablePool",
loop_variable_selectors: dict,
break_conditions: list,
logical_operator: Literal["and", "or"],
condition_processor: ConditionProcessor,
current_index: int,
start_at: datetime,
inputs: dict,
) -> Generator[NodeEvent | InNodeEvent, None, dict]:
"""Run a single loop iteration.
Returns:
dict: {'check_break_result': bool}
"""
# Run workflow
rst = graph_engine.run()
current_index_variable = variable_pool.get([self.node_id, "index"])
@@ -127,6 +281,15 @@ class LoopNode(BaseNode[LoopNodeData]):
):
continue
if (
isinstance(event, NodeRunSucceededEvent)
and event.node_type == NodeType.LOOP_END
and not isinstance(event, NodeRunStreamChunkEvent)
):
check_break_result = True
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
break
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
@@ -157,7 +320,7 @@ class LoopNode(BaseNode[LoopNodeData]):
loop_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
steps=i,
steps=current_index,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
@@ -168,12 +331,10 @@ class LoopNode(BaseNode[LoopNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
},
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
)
)
return
return {"check_break_result": True}
elif isinstance(event, NodeRunFailedEvent):
# Loop run failed
yield event
@@ -184,7 +345,7 @@ class LoopNode(BaseNode[LoopNodeData]):
loop_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
steps=i,
steps=current_index,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
@@ -195,12 +356,10 @@ class LoopNode(BaseNode[LoopNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
},
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
)
)
return
return {"check_break_result": True}
else:
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
@@ -208,8 +367,19 @@ class LoopNode(BaseNode[LoopNodeData]):
for node_id in loop_graph.node_ids:
variable_pool.remove([node_id])
_outputs = {}
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
_loop_variable_segment = variable_pool.get(loop_variable_selector)
if _loop_variable_segment:
_outputs[loop_variable_key] = _loop_variable_segment.value
else:
_outputs[loop_variable_key] = None
_outputs["loop_round"] = current_index + 1
self.node_data.outputs = _outputs
if check_break_result:
break
return {"check_break_result": True}
# Move to next loop
next_index = current_index + 1
@@ -221,60 +391,10 @@ class LoopNode(BaseNode[LoopNodeData]):
loop_node_type=self.node_type,
loop_node_data=self.node_data,
index=next_index,
pre_loop_output=None,
pre_loop_output=self.node_data.outputs,
)
# Loop completed successfully
yield LoopRunSucceededEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
steps=loop_count,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "loop_break" if check_break_result else "loop_completed",
},
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
)
)
except Exception as e:
# Loop failed
logger.exception("Loop run failed")
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
steps=loop_count,
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
)
)
finally:
# Clean up
variable_pool.remove([self.node_id, "index"])
return {"check_break_result": False}
def _handle_event_metadata(
self,
@@ -360,3 +480,25 @@ class LoopNode(BaseNode[LoopNodeData]):
}
return variable_mapping
@staticmethod
def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
"string": (StringSegment, SegmentType.STRING),
"number": (IntegerSegment, SegmentType.NUMBER),
"object": (ObjectSegment, SegmentType.OBJECT),
"array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
"array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
"array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
}
if var_type in ["array[string]", "array[number]", "array[object]"]:
if value:
value = json.loads(value)
else:
value = []
segment_info = segment_mapping.get(var_type)
if not segment_info:
raise ValueError(f"Invalid variable type: {var_type}")
segment_class, value_type = segment_info
return segment_class(value=value, value_type=value_type)
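A quick sketch of what the helper returns, based only on the mapping above; array constants arrive as JSON strings and are decoded with json.loads, with an empty value falling back to []:

# Illustrative calls; the private helper is invoked directly here just for demonstration.
LoopNode._get_segment_for_constant("number", 5)                    # IntegerSegment(value=5)
LoopNode._get_segment_for_constant("array[string]", '["a", "b"]')  # ArrayStringSegment(value=["a", "b"])
LoopNode._get_segment_for_constant("array[object]", "")            # ArrayObjectSegment(value=[])
LoopNode._get_segment_for_constant("boolean", True)                # ValueError: Invalid variable type: boolean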

View File

@@ -13,7 +13,7 @@ from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.loop import LoopNode, LoopStartNode
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from core.workflow.nodes.question_classifier import QuestionClassifierNode
from core.workflow.nodes.start import StartNode
@@ -94,6 +94,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
LATEST_VERSION: LoopStartNode,
"1": LoopStartNode,
},
NodeType.LOOP_END: {
LATEST_VERSION: LoopEndNode,
"1": LoopEndNode,
},
NodeType.PARAMETER_EXTRACTOR: {
LATEST_VERSION: ParameterExtractorNode,
"1": ParameterExtractorNode,

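With the new entry in place, resolving the implementation class for a loop-end node is the same dictionary lookup used for every other node type; a sketch, where node_version stands in for whatever version string the workflow config carries (an assumption, not shown in this diff):

# Hypothetical lookup mirroring the mapping above.
node_version = "1"  # or LATEST_VERSION when no version is pinned
loop_end_cls = NODE_TYPE_CLASSES_MAPPING[NodeType.LOOP_END][node_version]
assert loop_end_cls is LoopEndNode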
View File

@@ -2,6 +2,7 @@ import json
from collections.abc import Sequence
from typing import Any, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
@@ -123,6 +124,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
if self.invoke_from != InvokeFrom.DEBUGGER:
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value