diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 70d40d87e9..82fd6cdc30 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -30,6 +30,7 @@ class NodeRunMetadataKey(StrEnum): ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field + LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output class NodeRunResult(BaseModel): diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index d9a2c2d8a8..73b43eeaf7 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -17,6 +17,7 @@ class NodeType(StrEnum): LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. LOOP = "loop" LOOP_START = "loop-start" + LOOP_END = "loop-end" ITERATION = "iteration" ITERATION_START = "iteration-start" # Fake start node for iteration. 
class LoopVariableData(BaseModel):
    """
    Loop Variable Data.

    Describes one variable that a loop node seeds before its first round.
    The loop node stores the resolved value in the variable pool under
    ``[<loop node id>, label]``.
    """

    # Variable name; second element of the selector used in the variable pool.
    label: str
    # Declared type; picks the segment class when the value is a constant.
    var_type: Literal[
        "string",
        "number",
        "object",
        "array[string]",
        "array[number]",
        "array[object]",
    ]
    # "constant": ``value`` is the literal initial value.
    # "variable": ``value`` is a selector looked up in the variable pool.
    value_type: Literal["variable", "constant"]
    # Initial value or selector, depending on ``value_type``.
    value: Optional[Any | list[str]] = None
class LoopEndNode(BaseNode[LoopEndNodeData]):
    """
    Loop End Node.

    A marker node that performs no work and always succeeds. The loop node
    treats a succeeded event from a node of this type as a request to stop
    looping.
    """

    _node_data_cls = LoopEndNodeData
    _node_type = NodeType.LOOP_END

    def _run(self) -> NodeRunResult:
        """
        Report success; the node carries no inputs, outputs or metadata.
        """
        succeeded = WorkflowNodeExecutionStatus.SUCCEEDED
        return NodeRunResult(status=succeeded)
import VariablePool + from core.workflow.graph_engine.graph_engine import GraphEngine + logger = logging.getLogger(__name__) @@ -61,6 +75,28 @@ class LoopNode(BaseNode[LoopNodeData]): variable_pool = self.graph_runtime_state.variable_pool variable_pool.add([self.node_id, "index"], 0) + # Initialize loop variables + loop_variable_selectors = {} + if self.node_data.loop_variables: + for loop_variable in self.node_data.loop_variables: + value_processor = { + "constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value), + "variable": lambda var=loop_variable: variable_pool.get(var.value), + } + + if loop_variable.value_type not in value_processor: + raise ValueError( + f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}" + ) + + processed_segment = value_processor[loop_variable.value_type]() + if not processed_segment: + raise ValueError(f"Invalid value for loop variable {loop_variable.label}") + variable_selector = [self.node_id, loop_variable.label] + variable_pool.add(variable_selector, processed_segment.value) + loop_variable_selectors[loop_variable.label] = variable_selector + inputs[loop_variable.label] = processed_segment.value + from core.workflow.graph_engine.graph_engine import GraphEngine graph_engine = GraphEngine( @@ -95,135 +131,51 @@ class LoopNode(BaseNode[LoopNodeData]): predecessor_node_id=self.previous_node_id, ) - yield LoopRunNextEvent( - loop_id=self.id, - loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, - index=0, - pre_loop_output=None, - ) - + # yield LoopRunNextEvent( + # loop_id=self.id, + # loop_node_id=self.node_id, + # loop_node_type=self.node_type, + # loop_node_data=self.node_data, + # index=0, + # pre_loop_output=None, + # ) + loop_duration_map = {} + single_loop_variable_map = {} # single loop variable output try: check_break_result = False for i in range(loop_count): - # Run workflow - rst = graph_engine.run() - 
current_index_variable = variable_pool.get([self.node_id, "index"]) - if not isinstance(current_index_variable, IntegerSegment): - raise ValueError(f"loop {self.node_id} current index not found") - current_index = current_index_variable.value + loop_start_time = datetime.now(UTC).replace(tzinfo=None) + # run single loop + loop_result = yield from self._run_single_loop( + graph_engine=graph_engine, + loop_graph=loop_graph, + variable_pool=variable_pool, + loop_variable_selectors=loop_variable_selectors, + break_conditions=break_conditions, + logical_operator=logical_operator, + condition_processor=condition_processor, + current_index=i, + start_at=start_at, + inputs=inputs, + ) + loop_end_time = datetime.now(UTC).replace(tzinfo=None) - check_break_result = False - - for event in rst: - if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id: - event.in_loop_id = self.node_id - - if ( - isinstance(event, BaseNodeEvent) - and event.node_type == NodeType.LOOP_START - and not isinstance(event, NodeRunStreamChunkEvent) - ): - continue - - if isinstance(event, NodeRunSucceededEvent): - yield self._handle_event_metadata(event=event, iter_run_index=current_index) - - # Check if all variables in break conditions exist - exists_variable = False - for condition in break_conditions: - if not self.graph_runtime_state.variable_pool.get(condition.variable_selector): - exists_variable = False - break - else: - exists_variable = True - if exists_variable: - input_conditions, group_result, check_break_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) - if check_break_result: - break - - elif isinstance(event, BaseGraphEvent): - if isinstance(event, GraphRunFailedEvent): - # Loop run failed - yield LoopRunFailedEvent( - loop_id=self.id, - loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, - 
start_at=start_at, - inputs=inputs, - steps=i, - metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, - "completed_reason": "error", - }, - error=event.error, - ) - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=event.error, - metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens - }, - ) - ) - return - elif isinstance(event, NodeRunFailedEvent): - # Loop run failed - yield event - yield LoopRunFailedEvent( - loop_id=self.id, - loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, - start_at=start_at, - inputs=inputs, - steps=i, - metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, - "completed_reason": "error", - }, - error=event.error, - ) - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=event.error, - metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens - }, - ) - ) - return + single_loop_variable = {} + for key, selector in loop_variable_selectors.items(): + item = variable_pool.get(selector) + if item: + single_loop_variable[key] = item.value else: - yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index) + single_loop_variable[key] = None - # Remove all nodes outputs from variable pool - for node_id in loop_graph.node_ids: - variable_pool.remove([node_id]) + loop_duration_map[str(i)] = (loop_end_time - loop_start_time).total_seconds() + single_loop_variable_map[str(i)] = single_loop_variable + + check_break_result = loop_result.get("check_break_result", False) if check_break_result: break - # Move to next loop - next_index = current_index + 1 - variable_pool.add([self.node_id, "index"], next_index) - - yield LoopRunNextEvent( - loop_id=self.id, - loop_node_id=self.node_id, - loop_node_type=self.node_type, - 
loop_node_data=self.node_data, - index=next_index, - pre_loop_output=None, - ) - # Loop completed successfully yield LoopRunSucceededEvent( loop_id=self.id, @@ -232,17 +184,26 @@ class LoopNode(BaseNode[LoopNodeData]): loop_node_data=self.node_data, start_at=start_at, inputs=inputs, + outputs=self.node_data.outputs, steps=loop_count, metadata={ NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, "completed_reason": "loop_break" if check_break_result else "loop_completed", + NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, + NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, ) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens}, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, + NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, + NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, + }, + outputs=self.node_data.outputs, + inputs=inputs, ) ) @@ -260,6 +221,8 @@ class LoopNode(BaseNode[LoopNodeData]): metadata={ "total_tokens": graph_engine.graph_runtime_state.total_tokens, "completed_reason": "error", + NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, + NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, error=str(e), ) @@ -268,7 +231,11 @@ class LoopNode(BaseNode[LoopNodeData]): run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(e), - metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens}, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, + NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, + NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, + }, ) ) @@ -276,6 +243,159 @@ class LoopNode(BaseNode[LoopNodeData]): # Clean up variable_pool.remove([self.node_id, "index"]) + 
def _run_single_loop(
    self,
    *,
    graph_engine: "GraphEngine",
    loop_graph: Graph,
    variable_pool: "VariablePool",
    loop_variable_selectors: dict,
    break_conditions: list,
    logical_operator: Literal["and", "or"],
    condition_processor: ConditionProcessor,
    current_index: int,
    start_at: datetime,
    inputs: dict,
) -> Generator[NodeEvent | InNodeEvent, None, dict]:
    """Run one round of the loop sub-graph and report whether the loop should stop.

    Events produced by the sub-graph are re-emitted (tagged with this loop's id
    and the current round index). After the round, the outputs of all sub-graph
    nodes are removed from the variable pool and the current loop-variable
    values are stored on ``self.node_data.outputs``.

    Args:
        graph_engine: Engine whose ``run()`` executes the loop sub-graph.
        loop_graph: Sub-graph; its ``node_ids`` are cleared from the pool after the round.
        variable_pool: Pool holding the loop index and loop variables.
        loop_variable_selectors: Maps loop-variable label -> selector in the pool.
        break_conditions: Conditions evaluated after every succeeded node.
        logical_operator: How ``break_conditions`` are combined.
        condition_processor: Evaluates ``break_conditions``.
        current_index: Accepted for call compatibility; the index stored in the
            variable pool under ``[node_id, "index"]`` is what is actually used.
        start_at: Accepted for call compatibility; failure events are emitted by the caller.
        inputs: Accepted for call compatibility; failure events are emitted by the caller.

    Returns:
        dict: ``{"check_break_result": bool}`` - True when a LOOP_END node
        succeeded or the break conditions were met.

    Raises:
        ValueError: If the loop index is missing from the variable pool, or if
            the sub-graph (or one of its nodes) failed. The message is the
            failure's ``error`` text.
    """
    # Run workflow
    sub_graph_events = graph_engine.run()

    index_segment = variable_pool.get([self.node_id, "index"])
    if not isinstance(index_segment, IntegerSegment):
        raise ValueError(f"loop {self.node_id} current index not found")
    run_index = index_segment.value

    check_break_result = False

    for event in sub_graph_events:
        if isinstance(event, (BaseNodeEvent, BaseParallelBranchEvent)) and not event.in_loop_id:
            event.in_loop_id = self.node_id

        # The synthetic loop-start node is hidden from consumers, except for its stream chunks.
        if (
            isinstance(event, BaseNodeEvent)
            and event.node_type == NodeType.LOOP_START
            and not isinstance(event, NodeRunStreamChunkEvent)
        ):
            continue

        # Reaching a loop-end node is an unconditional break.
        if isinstance(event, NodeRunSucceededEvent) and event.node_type == NodeType.LOOP_END:
            check_break_result = True
            yield self._handle_event_metadata(event=event, iter_run_index=run_index)
            break

        if isinstance(event, NodeRunSucceededEvent):
            yield self._handle_event_metadata(event=event, iter_run_index=run_index)

            # Only evaluate the break conditions once every variable they reference exists.
            pool = self.graph_runtime_state.variable_pool
            conditions_ready = bool(break_conditions) and all(
                pool.get(condition.variable_selector) for condition in break_conditions
            )
            if conditions_ready:
                _, _, check_break_result = condition_processor.process_conditions(
                    variable_pool=pool,
                    conditions=break_conditions,
                    operator=logical_operator,
                )
            if check_break_result:
                break

        elif isinstance(event, BaseGraphEvent):
            if isinstance(event, GraphRunFailedEvent):
                # Returning a "break" result here would make the caller treat the
                # failure as a normal break and emit success events afterwards.
                # Raising hands the failure to the caller's error path instead, so
                # exactly one failed loop event / run result is produced.
                # NOTE(review): assumes the caller's except clause covers ValueError,
                # like the "current index not found" error above - confirm.
                raise ValueError(event.error)
        elif isinstance(event, NodeRunFailedEvent):
            # Surface the failing node first, then fail the loop (see note above).
            yield event
            raise ValueError(event.error)
        else:
            yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=run_index)

    # Remove all nodes outputs from variable pool
    for node_id in loop_graph.node_ids:
        variable_pool.remove([node_id])

    # Snapshot the loop variables as this round's outputs.
    round_outputs = {}
    for label, selector in loop_variable_selectors.items():
        segment = variable_pool.get(selector)
        round_outputs[label] = segment.value if segment else None
    round_outputs["loop_round"] = run_index + 1
    self.node_data.outputs = round_outputs

    if check_break_result:
        return {"check_break_result": True}

    # Move to next loop
    next_index = run_index + 1
    variable_pool.add([self.node_id, "index"], next_index)

    yield LoopRunNextEvent(
        loop_id=self.id,
        loop_node_id=self.node_id,
        loop_node_type=self.node_type,
        loop_node_data=self.node_data,
        index=next_index,
        pre_loop_output=self.node_data.outputs,
    )

    return {"check_break_result": False}
@staticmethod
def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
    """Build the segment that wraps a constant loop-variable value.

    Args:
        var_type: Declared loop-variable type; must be a key of the mapping below.
        value: Raw constant. For array types this may be a JSON-encoded string
            or an already-decoded list; any empty value becomes ``[]``.

    Returns:
        An instance of the segment class registered for ``var_type``.

    Raises:
        ValueError: If ``var_type`` is unknown, or an array constant is not
            valid JSON or does not decode to a list.
    """
    # NOTE(review): "number" is always wrapped in IntegerSegment; confirm that
    # non-integer numeric constants are either impossible here or acceptable.
    segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
        "string": (StringSegment, SegmentType.STRING),
        "number": (IntegerSegment, SegmentType.NUMBER),
        "object": (ObjectSegment, SegmentType.OBJECT),
        "array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
        "array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
        "array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
    }
    segment_info = segment_mapping.get(var_type)
    if not segment_info:
        raise ValueError(f"Invalid variable type: {var_type}")
    segment_class, value_type = segment_info

    if var_type in ("array[string]", "array[number]", "array[object]"):
        if not value:
            value = []
        elif isinstance(value, (str, bytes, bytearray)):
            # Array constants normally arrive JSON-encoded.
            value = json.loads(value)
        # An already-decoded list is used as is; json.loads would raise TypeError on it.
        if not isinstance(value, list):
            raise ValueError(f"Invalid value for {var_type} constant: expected a JSON array")

    return segment_class(value=value, value_type=value_type)
import ListOperatorNode from core.workflow.nodes.llm import LLMNode -from core.workflow.nodes.loop import LoopNode, LoopStartNode +from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode from core.workflow.nodes.parameter_extractor import ParameterExtractorNode from core.workflow.nodes.question_classifier import QuestionClassifierNode from core.workflow.nodes.start import StartNode @@ -94,6 +94,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { LATEST_VERSION: LoopStartNode, "1": LoopStartNode, }, + NodeType.LOOP_END: { + LATEST_VERSION: LoopEndNode, + "1": LoopEndNode, + }, NodeType.PARAMETER_EXTRACTOR: { LATEST_VERSION: ParameterExtractorNode, "1": ParameterExtractorNode, diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index afa5656f46..0305eb7f41 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -2,6 +2,7 @@ import json from collections.abc import Sequence from typing import Any, cast +from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunResult @@ -123,13 +124,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID: conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) if not conversation_id: - raise ConversationIDNotFoundError + if self.invoke_from != InvokeFrom.DEBUGGER: + raise ConversationIDNotFoundError else: conversation_id = conversation_id.value - common_helpers.update_conversation_variable( - conversation_id=cast(str, conversation_id), - variable=variable, - ) + common_helpers.update_conversation_variable( + conversation_id=cast(str, conversation_id), + variable=variable, + 
) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED,