mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 14:15:52 +08:00
feat(api): add a version
class method to BaseNode and subclasses
This ensures that we can get the version of node while executing. Add `node_version` to `BaseNodeEvent` to ensure that all node related events includes node version information.
This commit is contained in:
parent
392ad17497
commit
25b4a96aed
@ -65,6 +65,8 @@ class BaseNodeEvent(GraphEngineEvent):
|
|||||||
"""iteration id if node is in iteration"""
|
"""iteration id if node is in iteration"""
|
||||||
in_loop_id: Optional[str] = None
|
in_loop_id: Optional[str] = None
|
||||||
"""loop id if node is in loop"""
|
"""loop id if node is in loop"""
|
||||||
|
# The version of the node, or "1" if not specified.
|
||||||
|
node_version: str = "1"
|
||||||
|
|
||||||
|
|
||||||
class NodeRunStartedEvent(BaseNodeEvent):
|
class NodeRunStartedEvent(BaseNodeEvent):
|
||||||
|
@ -313,6 +313,7 @@ class GraphEngine:
|
|||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
node_version=node_instance.version(),
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@ -630,6 +631,7 @@ class GraphEngine:
|
|||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
agent_strategy=agent_strategy,
|
agent_strategy=agent_strategy,
|
||||||
|
node_version=node_instance.version(),
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.close()
|
db.session.close()
|
||||||
@ -688,6 +690,7 @@ class GraphEngine:
|
|||||||
error=run_result.error or "Unknown error",
|
error=run_result.error or "Unknown error",
|
||||||
retry_index=retries,
|
retry_index=retries,
|
||||||
start_at=retry_start_at,
|
start_at=retry_start_at,
|
||||||
|
node_version=node_instance.version(),
|
||||||
)
|
)
|
||||||
time.sleep(retry_interval)
|
time.sleep(retry_interval)
|
||||||
break
|
break
|
||||||
@ -723,6 +726,7 @@ class GraphEngine:
|
|||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
node_version=node_instance.version(),
|
||||||
)
|
)
|
||||||
should_continue_retry = False
|
should_continue_retry = False
|
||||||
else:
|
else:
|
||||||
@ -737,6 +741,7 @@ class GraphEngine:
|
|||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
node_version=node_instance.version(),
|
||||||
)
|
)
|
||||||
should_continue_retry = False
|
should_continue_retry = False
|
||||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||||
@ -791,6 +796,7 @@ class GraphEngine:
|
|||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
node_version=node_instance.version(),
|
||||||
)
|
)
|
||||||
should_continue_retry = False
|
should_continue_retry = False
|
||||||
|
|
||||||
@ -808,6 +814,7 @@ class GraphEngine:
|
|||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
node_version=node_instance.version(),
|
||||||
)
|
)
|
||||||
elif isinstance(item, RunRetrieverResourceEvent):
|
elif isinstance(item, RunRetrieverResourceEvent):
|
||||||
yield NodeRunRetrieverResourceEvent(
|
yield NodeRunRetrieverResourceEvent(
|
||||||
@ -822,6 +829,7 @@ class GraphEngine:
|
|||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
node_version=node_instance.version(),
|
||||||
)
|
)
|
||||||
except GenerateTaskStoppedError:
|
except GenerateTaskStoppedError:
|
||||||
# trigger node run failed event
|
# trigger node run failed event
|
||||||
@ -838,6 +846,7 @@ class GraphEngine:
|
|||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
node_version=node_instance.version(),
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -18,7 +18,11 @@ from models.workflow import WorkflowNodeExecutionStatus
|
|||||||
|
|
||||||
class AnswerNode(BaseNode[AnswerNodeData]):
|
class AnswerNode(BaseNode[AnswerNodeData]):
|
||||||
_node_data_cls = AnswerNodeData
|
_node_data_cls = AnswerNodeData
|
||||||
_node_type: NodeType = NodeType.ANSWER
|
_node_type = NodeType.ANSWER
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
"""
|
"""
|
||||||
|
@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
|||||||
parallel_id=event.parallel_id,
|
parallel_id=event.parallel_id,
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
from_variable_selector=[answer_node_id, "answer"],
|
from_variable_selector=[answer_node_id, "answer"],
|
||||||
|
node_version=event.node_version,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||||
@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
|||||||
route_node_state=event.route_node_state,
|
route_node_state=event.route_node_state,
|
||||||
parallel_id=event.parallel_id,
|
parallel_id=event.parallel_id,
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
node_version=event.node_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.route_position[answer_node_id] += 1
|
self.route_position[answer_node_id] += 1
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
|
||||||
|
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
|
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
|
||||||
@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
|
|||||||
|
|
||||||
class BaseNode(Generic[GenericNodeData]):
|
class BaseNode(Generic[GenericNodeData]):
|
||||||
_node_data_cls: type[GenericNodeData]
|
_node_data_cls: type[GenericNodeData]
|
||||||
_node_type: NodeType
|
_node_type: ClassVar[NodeType]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -101,9 +101,10 @@ class BaseNode(Generic[GenericNodeData]):
|
|||||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||||
|
|
||||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||||
return cls._extract_variable_selector_to_variable_mapping(
|
data = cls._extract_variable_selector_to_variable_mapping(
|
||||||
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
|
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
|
||||||
)
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
@ -139,6 +140,16 @@ class BaseNode(Generic[GenericNodeData]):
|
|||||||
"""
|
"""
|
||||||
return self._node_type
|
return self._node_type
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
"""`node_version` returns the version of current node type."""
|
||||||
|
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
|
||||||
|
#
|
||||||
|
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
|
||||||
|
# in `api/core/workflow/nodes/__init__.py`.
|
||||||
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def should_continue_on_error(self) -> bool:
|
def should_continue_on_error(self) -> bool:
|
||||||
"""judge if should continue on error
|
"""judge if should continue on error
|
||||||
|
@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||||||
|
|
||||||
return code_provider.get_default_config()
|
return code_provider.get_default_config()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
# Get code language
|
# Get code language
|
||||||
code_language = self.node_data.code_language
|
code_language = self.node_data.code_language
|
||||||
|
@ -44,6 +44,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
|
|||||||
_node_data_cls = DocumentExtractorNodeData
|
_node_data_cls = DocumentExtractorNodeData
|
||||||
_node_type = NodeType.DOCUMENT_EXTRACTOR
|
_node_type = NodeType.DOCUMENT_EXTRACTOR
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self):
|
def _run(self):
|
||||||
variable_selector = self.node_data.variable_selector
|
variable_selector = self.node_data.variable_selector
|
||||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
|
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
|
||||||
|
@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]):
|
|||||||
_node_data_cls = EndNodeData
|
_node_data_cls = EndNodeData
|
||||||
_node_type = NodeType.END
|
_node_type = NodeType.END
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
"""
|
"""
|
||||||
Run node
|
Run node
|
||||||
|
@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor):
|
|||||||
route_node_state=event.route_node_state,
|
route_node_state=event.route_node_state,
|
||||||
parallel_id=event.parallel_id,
|
parallel_id=event.parallel_id,
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
node_version=event.node_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.route_position[end_node_id] += 1
|
self.route_position[end_node_id] += 1
|
||||||
|
@ -60,6 +60,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
process_data = {}
|
process_data = {}
|
||||||
try:
|
try:
|
||||||
|
@ -16,6 +16,10 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
|||||||
_node_data_cls = IfElseNodeData
|
_node_data_cls = IfElseNodeData
|
||||||
_node_type = NodeType.IF_ELSE
|
_node_type = NodeType.IF_ELSE
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
"""
|
"""
|
||||||
Run node
|
Run node
|
||||||
|
@ -72,6 +72,10 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||||
"""
|
"""
|
||||||
Run the node.
|
Run the node.
|
||||||
|
@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]):
|
|||||||
_node_data_cls = IterationStartNodeData
|
_node_data_cls = IterationStartNodeData
|
||||||
_node_type = NodeType.ITERATION_START
|
_node_type = NodeType.ITERATION_START
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
"""
|
"""
|
||||||
Run the node.
|
Run the node.
|
||||||
|
@ -16,6 +16,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
|||||||
_node_data_cls = ListOperatorNodeData
|
_node_data_cls = ListOperatorNodeData
|
||||||
_node_type = NodeType.LIST_OPERATOR
|
_node_type = NodeType.LIST_OPERATOR
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self):
|
def _run(self):
|
||||||
inputs: dict[str, list] = {}
|
inputs: dict[str, list] = {}
|
||||||
process_data: dict[str, list] = {}
|
process_data: dict[str, list] = {}
|
||||||
|
@ -148,6 +148,10 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
)
|
)
|
||||||
self._llm_file_saver = llm_file_saver
|
self._llm_file_saver = llm_file_saver
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||||
def process_structured_output(text: str) -> Optional[dict[str, Any]]:
|
def process_structured_output(text: str) -> Optional[dict[str, Any]]:
|
||||||
"""Process structured output if enabled"""
|
"""Process structured output if enabled"""
|
||||||
|
@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]):
|
|||||||
_node_data_cls = LoopEndNodeData
|
_node_data_cls = LoopEndNodeData
|
||||||
_node_type = NodeType.LOOP_END
|
_node_type = NodeType.LOOP_END
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
"""
|
"""
|
||||||
Run the node.
|
Run the node.
|
||||||
|
@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||||||
_node_data_cls = LoopNodeData
|
_node_data_cls = LoopNodeData
|
||||||
_node_type = NodeType.LOOP
|
_node_type = NodeType.LOOP
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||||
"""Run the node."""
|
"""Run the node."""
|
||||||
# Get inputs
|
# Get inputs
|
||||||
|
@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]):
|
|||||||
_node_data_cls = LoopStartNodeData
|
_node_data_cls = LoopStartNodeData
|
||||||
_node_type = NodeType.LOOP_START
|
_node_type = NodeType.LOOP_START
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
"""
|
"""
|
||||||
Run the node.
|
Run the node.
|
||||||
|
@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var
|
|||||||
|
|
||||||
LATEST_VERSION = "latest"
|
LATEST_VERSION = "latest"
|
||||||
|
|
||||||
|
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
|
||||||
|
# Specifically, if you have introduced new node types, you should add them here.
|
||||||
|
#
|
||||||
|
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
|
||||||
|
# hook. Try to avoid duplication of node information.
|
||||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||||
NodeType.START: {
|
NodeType.START: {
|
||||||
LATEST_VERSION: StartNode,
|
LATEST_VERSION: StartNode,
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from core.file.constants import add_dummy_output
|
||||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.nodes.base import BaseNode
|
from core.workflow.nodes.base import BaseNode
|
||||||
@ -10,6 +11,10 @@ class StartNode(BaseNode[StartNodeData]):
|
|||||||
_node_data_cls = StartNodeData
|
_node_data_cls = StartNodeData
|
||||||
_node_type = NodeType.START
|
_node_type = NodeType.START
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables
|
system_inputs = self.graph_runtime_state.variable_pool.system_variables
|
||||||
@ -18,5 +23,9 @@ class StartNode(BaseNode[StartNodeData]):
|
|||||||
# Set system variables as node outputs.
|
# Set system variables as node outputs.
|
||||||
for var in system_inputs:
|
for var in system_inputs:
|
||||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
|
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
|
||||||
|
outputs = dict(node_inputs)
|
||||||
|
# Need special handling for `Start` node, as all other output variables
|
||||||
|
# are treated as systemd variables.
|
||||||
|
add_dummy_output(outputs)
|
||||||
|
|
||||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs)
|
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)
|
||||||
|
@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
|
|||||||
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
|
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
# Get variables
|
# Get variables
|
||||||
variables = {}
|
variables = {}
|
||||||
|
@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||||||
_node_data_cls = ToolNodeData
|
_node_data_cls = ToolNodeData
|
||||||
_node_type = NodeType.TOOL
|
_node_type = NodeType.TOOL
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator:
|
||||||
"""
|
"""
|
||||||
Run the tool node
|
Run the tool node
|
||||||
|
@ -9,6 +9,10 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
|
|||||||
_node_data_cls = VariableAssignerNodeData
|
_node_data_cls = VariableAssignerNodeData
|
||||||
_node_type = NodeType.VARIABLE_AGGREGATOR
|
_node_type = NodeType.VARIABLE_AGGREGATOR
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
# Get variables
|
# Get variables
|
||||||
outputs = {}
|
outputs = {}
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.variables import Variable
|
from core.variables import Segment, SegmentType, Variable
|
||||||
|
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models import ConversationVariable
|
from models import ConversationVariable
|
||||||
@ -17,3 +21,22 @@ def update_conversation_variable(conversation_id: str, variable: Variable):
|
|||||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||||
row.data = variable.model_dump_json()
|
row.data = variable.model_dump_json()
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
class VariableOutput(TypedDict):
|
||||||
|
name: str
|
||||||
|
selector: Sequence[str]
|
||||||
|
new_value: Any
|
||||||
|
type: SegmentType
|
||||||
|
|
||||||
|
|
||||||
|
def variable_to_output_mapping(selector: Sequence[str], seg: Segment) -> VariableOutput:
|
||||||
|
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||||
|
raise Exception("selector too short")
|
||||||
|
node_id, var_name = selector[:2]
|
||||||
|
return {
|
||||||
|
"name": var_name,
|
||||||
|
"selector": selector[:2],
|
||||||
|
"new_value": seg.value,
|
||||||
|
"type": seg.value_type,
|
||||||
|
}
|
||||||
|
@ -14,9 +14,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
|||||||
_node_data_cls = VariableAssignerData
|
_node_data_cls = VariableAssignerData
|
||||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
|
assigned_variable_selector = self.node_data.assigned_variable_selector
|
||||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||||
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
|
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
|
||||||
if not isinstance(original_variable, Variable):
|
if not isinstance(original_variable, Variable):
|
||||||
raise VariableOperatorNodeError("assigned variable not found")
|
raise VariableOperatorNodeError("assigned variable not found")
|
||||||
|
|
||||||
@ -44,7 +49,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
|||||||
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||||
|
|
||||||
# Over write the variable.
|
# Over write the variable.
|
||||||
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
|
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
||||||
|
|
||||||
# TODO: Move database operation to the pipeline.
|
# TODO: Move database operation to the pipeline.
|
||||||
# Update conversation variable.
|
# Update conversation variable.
|
||||||
@ -58,6 +63,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
|||||||
inputs={
|
inputs={
|
||||||
"value": income_value.to_object(),
|
"value": income_value.to_object(),
|
||||||
},
|
},
|
||||||
|
outputs={
|
||||||
|
# NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`,
|
||||||
|
# we still set `output_variables` as a list to ensure the schema of output is
|
||||||
|
# compatible with `v2.VariableAssignerNode`.
|
||||||
|
"updated_variables": [
|
||||||
|
common_helpers.variable_to_output_mapping(assigned_variable_selector, updated_variable)
|
||||||
|
]
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,6 +29,10 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
|||||||
_node_data_cls = VariableAssignerNodeData
|
_node_data_cls = VariableAssignerNodeData
|
||||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "2"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
inputs = self.node_data.model_dump()
|
inputs = self.node_data.model_dump()
|
||||||
process_data: dict[str, Any] = {}
|
process_data: dict[str, Any] = {}
|
||||||
@ -137,6 +141,13 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
|||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
process_data=process_data,
|
process_data=process_data,
|
||||||
|
outputs={
|
||||||
|
"updated_variables": [
|
||||||
|
common_helpers.variable_to_output_mapping(selector, seg)
|
||||||
|
for selector in updated_variable_selectors
|
||||||
|
if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None
|
||||||
|
],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_item(
|
def _handle_item(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user