mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 00:39:04 +08:00
Feat/loop node (#17273)
This commit is contained in:
parent
11e95d2a61
commit
8c77f2dc03
@ -30,6 +30,7 @@ class NodeRunMetadataKey(StrEnum):
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
|
@ -17,6 +17,7 @@ class NodeType(StrEnum):
|
||||
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
LOOP = "loop"
|
||||
LOOP_START = "loop-start"
|
||||
LOOP_END = "loop-end"
|
||||
ITERATION = "iteration"
|
||||
ITERATION_START = "iteration-start" # Fake start node for iteration.
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
|
@ -1,5 +1,6 @@
|
||||
from .entities import LoopNodeData
|
||||
from .loop_end_node import LoopEndNode
|
||||
from .loop_node import LoopNode
|
||||
from .loop_start_node import LoopStartNode
|
||||
|
||||
__all__ = ["LoopNode", "LoopNodeData", "LoopStartNode"]
|
||||
__all__ = ["LoopEndNode", "LoopNode", "LoopNodeData", "LoopStartNode"]
|
||||
|
@ -1,11 +1,23 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class LoopVariableData(BaseModel):
|
||||
"""
|
||||
Loop Variable Data.
|
||||
"""
|
||||
|
||||
label: str
|
||||
var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
value_type: Literal["variable", "constant"]
|
||||
value: Optional[Any | list[str]] = None
|
||||
|
||||
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
"""
|
||||
Loop Node Data.
|
||||
@ -14,6 +26,8 @@ class LoopNodeData(BaseLoopNodeData):
|
||||
loop_count: int # Maximum number of loops
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list)
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
|
||||
|
||||
class LoopStartNodeData(BaseNodeData):
|
||||
@ -24,6 +38,14 @@ class LoopStartNodeData(BaseNodeData):
|
||||
pass
|
||||
|
||||
|
||||
class LoopEndNodeData(BaseNodeData):
|
||||
"""
|
||||
Loop End Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LoopState(BaseLoopState):
|
||||
"""
|
||||
Loop State.
|
||||
|
20
api/core/workflow/nodes/loop/loop_end_node.py
Normal file
20
api/core/workflow/nodes/loop/loop_end_node.py
Normal file
@ -0,0 +1,20 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.loop.entities import LoopEndNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class LoopEndNode(BaseNode[LoopEndNodeData]):
|
||||
"""
|
||||
Loop End Node.
|
||||
"""
|
||||
|
||||
_node_data_cls = LoopEndNodeData
|
||||
_node_type = NodeType.LOOP_END
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
|
@ -1,10 +1,20 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import IntegerSegment
|
||||
from core.variables import (
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
IntegerSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
SegmentType,
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
@ -29,6 +39,10 @@ from core.workflow.nodes.loop.entities import LoopNodeData
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -61,6 +75,28 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
variable_pool.add([self.node_id, "index"], 0)
|
||||
|
||||
# Initialize loop variables
|
||||
loop_variable_selectors = {}
|
||||
if self.node_data.loop_variables:
|
||||
for loop_variable in self.node_data.loop_variables:
|
||||
value_processor = {
|
||||
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
|
||||
"variable": lambda var=loop_variable: variable_pool.get(var.value),
|
||||
}
|
||||
|
||||
if loop_variable.value_type not in value_processor:
|
||||
raise ValueError(
|
||||
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
|
||||
)
|
||||
|
||||
processed_segment = value_processor[loop_variable.value_type]()
|
||||
if not processed_segment:
|
||||
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
|
||||
variable_selector = [self.node_id, loop_variable.label]
|
||||
variable_pool.add(variable_selector, processed_segment.value)
|
||||
loop_variable_selectors[loop_variable.label] = variable_selector
|
||||
inputs[loop_variable.label] = processed_segment.value
|
||||
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
@ -95,135 +131,51 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
predecessor_node_id=self.previous_node_id,
|
||||
)
|
||||
|
||||
yield LoopRunNextEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
index=0,
|
||||
pre_loop_output=None,
|
||||
)
|
||||
|
||||
# yield LoopRunNextEvent(
|
||||
# loop_id=self.id,
|
||||
# loop_node_id=self.node_id,
|
||||
# loop_node_type=self.node_type,
|
||||
# loop_node_data=self.node_data,
|
||||
# index=0,
|
||||
# pre_loop_output=None,
|
||||
# )
|
||||
loop_duration_map = {}
|
||||
single_loop_variable_map = {} # single loop variable output
|
||||
try:
|
||||
check_break_result = False
|
||||
for i in range(loop_count):
|
||||
# Run workflow
|
||||
rst = graph_engine.run()
|
||||
current_index_variable = variable_pool.get([self.node_id, "index"])
|
||||
if not isinstance(current_index_variable, IntegerSegment):
|
||||
raise ValueError(f"loop {self.node_id} current index not found")
|
||||
current_index = current_index_variable.value
|
||||
loop_start_time = datetime.now(UTC).replace(tzinfo=None)
|
||||
# run single loop
|
||||
loop_result = yield from self._run_single_loop(
|
||||
graph_engine=graph_engine,
|
||||
loop_graph=loop_graph,
|
||||
variable_pool=variable_pool,
|
||||
loop_variable_selectors=loop_variable_selectors,
|
||||
break_conditions=break_conditions,
|
||||
logical_operator=logical_operator,
|
||||
condition_processor=condition_processor,
|
||||
current_index=i,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
)
|
||||
loop_end_time = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
check_break_result = False
|
||||
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
|
||||
event.in_loop_id = self.node_id
|
||||
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
and event.node_type == NodeType.LOOP_START
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
continue
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
|
||||
# Check if all variables in break conditions exist
|
||||
exists_variable = False
|
||||
for condition in break_conditions:
|
||||
if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
|
||||
exists_variable = False
|
||||
break
|
||||
else:
|
||||
exists_variable = True
|
||||
if exists_variable:
|
||||
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
if check_break_result:
|
||||
break
|
||||
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=i,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
)
|
||||
)
|
||||
return
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield event
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=i,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
)
|
||||
)
|
||||
return
|
||||
single_loop_variable = {}
|
||||
for key, selector in loop_variable_selectors.items():
|
||||
item = variable_pool.get(selector)
|
||||
if item:
|
||||
single_loop_variable[key] = item.value
|
||||
else:
|
||||
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
|
||||
single_loop_variable[key] = None
|
||||
|
||||
# Remove all nodes outputs from variable pool
|
||||
for node_id in loop_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
loop_duration_map[str(i)] = (loop_end_time - loop_start_time).total_seconds()
|
||||
single_loop_variable_map[str(i)] = single_loop_variable
|
||||
|
||||
check_break_result = loop_result.get("check_break_result", False)
|
||||
|
||||
if check_break_result:
|
||||
break
|
||||
|
||||
# Move to next loop
|
||||
next_index = current_index + 1
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
yield LoopRunNextEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
index=next_index,
|
||||
pre_loop_output=None,
|
||||
)
|
||||
|
||||
# Loop completed successfully
|
||||
yield LoopRunSucceededEvent(
|
||||
loop_id=self.id,
|
||||
@ -232,17 +184,26 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs=self.node_data.outputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "loop_break" if check_break_result else "loop_completed",
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
outputs=self.node_data.outputs,
|
||||
inputs=inputs,
|
||||
)
|
||||
)
|
||||
|
||||
@ -260,6 +221,8 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
error=str(e),
|
||||
)
|
||||
@ -268,7 +231,11 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@ -276,6 +243,159 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
# Clean up
|
||||
variable_pool.remove([self.node_id, "index"])
|
||||
|
||||
def _run_single_loop(
|
||||
self,
|
||||
*,
|
||||
graph_engine: "GraphEngine",
|
||||
loop_graph: Graph,
|
||||
variable_pool: "VariablePool",
|
||||
loop_variable_selectors: dict,
|
||||
break_conditions: list,
|
||||
logical_operator: Literal["and", "or"],
|
||||
condition_processor: ConditionProcessor,
|
||||
current_index: int,
|
||||
start_at: datetime,
|
||||
inputs: dict,
|
||||
) -> Generator[NodeEvent | InNodeEvent, None, dict]:
|
||||
"""Run a single loop iteration.
|
||||
Returns:
|
||||
dict: {'check_break_result': bool}
|
||||
"""
|
||||
# Run workflow
|
||||
rst = graph_engine.run()
|
||||
current_index_variable = variable_pool.get([self.node_id, "index"])
|
||||
if not isinstance(current_index_variable, IntegerSegment):
|
||||
raise ValueError(f"loop {self.node_id} current index not found")
|
||||
current_index = current_index_variable.value
|
||||
|
||||
check_break_result = False
|
||||
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
|
||||
event.in_loop_id = self.node_id
|
||||
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
and event.node_type == NodeType.LOOP_START
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
isinstance(event, NodeRunSucceededEvent)
|
||||
and event.node_type == NodeType.LOOP_END
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
check_break_result = True
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
break
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
|
||||
# Check if all variables in break conditions exist
|
||||
exists_variable = False
|
||||
for condition in break_conditions:
|
||||
if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
|
||||
exists_variable = False
|
||||
break
|
||||
else:
|
||||
exists_variable = True
|
||||
if exists_variable:
|
||||
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
if check_break_result:
|
||||
break
|
||||
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=current_index,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
)
|
||||
)
|
||||
return {"check_break_result": True}
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield event
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=current_index,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
)
|
||||
)
|
||||
return {"check_break_result": True}
|
||||
else:
|
||||
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
|
||||
|
||||
# Remove all nodes outputs from variable pool
|
||||
for node_id in loop_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
_outputs = {}
|
||||
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
|
||||
_loop_variable_segment = variable_pool.get(loop_variable_selector)
|
||||
if _loop_variable_segment:
|
||||
_outputs[loop_variable_key] = _loop_variable_segment.value
|
||||
else:
|
||||
_outputs[loop_variable_key] = None
|
||||
|
||||
_outputs["loop_round"] = current_index + 1
|
||||
self.node_data.outputs = _outputs
|
||||
|
||||
if check_break_result:
|
||||
return {"check_break_result": True}
|
||||
|
||||
# Move to next loop
|
||||
next_index = current_index + 1
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
yield LoopRunNextEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
index=next_index,
|
||||
pre_loop_output=self.node_data.outputs,
|
||||
)
|
||||
|
||||
return {"check_break_result": False}
|
||||
|
||||
def _handle_event_metadata(
|
||||
self,
|
||||
*,
|
||||
@ -360,3 +480,25 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
|
||||
"string": (StringSegment, SegmentType.STRING),
|
||||
"number": (IntegerSegment, SegmentType.NUMBER),
|
||||
"object": (ObjectSegment, SegmentType.OBJECT),
|
||||
"array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
|
||||
"array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
|
||||
"array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
|
||||
}
|
||||
if var_type in ["array[string]", "array[number]", "array[object]"]:
|
||||
if value:
|
||||
value = json.loads(value)
|
||||
else:
|
||||
value = []
|
||||
segment_info = segment_mapping.get(var_type)
|
||||
if not segment_info:
|
||||
raise ValueError(f"Invalid variable type: {var_type}")
|
||||
segment_class, value_type = segment_info
|
||||
return segment_class(value=value, value_type=value_type)
|
||||
|
@ -13,7 +13,7 @@ from core.workflow.nodes.iteration import IterationNode, IterationStartNode
|
||||
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.list_operator import ListOperatorNode
|
||||
from core.workflow.nodes.llm import LLMNode
|
||||
from core.workflow.nodes.loop import LoopNode, LoopStartNode
|
||||
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
|
||||
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from core.workflow.nodes.start import StartNode
|
||||
@ -94,6 +94,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
LATEST_VERSION: LoopStartNode,
|
||||
"1": LoopStartNode,
|
||||
},
|
||||
NodeType.LOOP_END: {
|
||||
LATEST_VERSION: LoopEndNode,
|
||||
"1": LoopEndNode,
|
||||
},
|
||||
NodeType.PARAMETER_EXTRACTOR: {
|
||||
LATEST_VERSION: ParameterExtractorNode,
|
||||
"1": ParameterExtractorNode,
|
||||
|
@ -2,6 +2,7 @@ import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
@ -123,13 +124,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
raise ConversationIDNotFoundError
|
||||
if self.invoke_from != InvokeFrom.DEBUGGER:
|
||||
raise ConversationIDNotFoundError
|
||||
else:
|
||||
conversation_id = conversation_id.value
|
||||
common_helpers.update_conversation_variable(
|
||||
conversation_id=cast(str, conversation_id),
|
||||
variable=variable,
|
||||
)
|
||||
common_helpers.update_conversation_variable(
|
||||
conversation_id=cast(str, conversation_id),
|
||||
variable=variable,
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
|
Loading…
x
Reference in New Issue
Block a user