feat(api): add a version class method to BaseNode and subclasses

This ensures that we can get the version of node while executing.

Add `node_version` to `BaseNodeEvent` to ensure that all node
related events includes node version information.
This commit is contained in:
QuantumGhost 2025-05-22 17:28:57 +08:00
parent 392ad17497
commit 25b4a96aed
26 changed files with 158 additions and 8 deletions

View File

@ -65,6 +65,8 @@ class BaseNodeEvent(GraphEngineEvent):
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
in_loop_id: Optional[str] = None in_loop_id: Optional[str] = None
"""loop id if node is in loop""" """loop id if node is in loop"""
# The version of the node, or "1" if not specified.
node_version: str = "1"
class NodeRunStartedEvent(BaseNodeEvent): class NodeRunStartedEvent(BaseNodeEvent):

View File

@ -313,6 +313,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
raise e raise e
@ -630,6 +631,7 @@ class GraphEngine:
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
agent_strategy=agent_strategy, agent_strategy=agent_strategy,
node_version=node_instance.version(),
) )
db.session.close() db.session.close()
@ -688,6 +690,7 @@ class GraphEngine:
error=run_result.error or "Unknown error", error=run_result.error or "Unknown error",
retry_index=retries, retry_index=retries,
start_at=retry_start_at, start_at=retry_start_at,
node_version=node_instance.version(),
) )
time.sleep(retry_interval) time.sleep(retry_interval)
break break
@ -723,6 +726,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
should_continue_retry = False should_continue_retry = False
else: else:
@ -737,6 +741,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
should_continue_retry = False should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
@ -791,6 +796,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
should_continue_retry = False should_continue_retry = False
@ -808,6 +814,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
elif isinstance(item, RunRetrieverResourceEvent): elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent( yield NodeRunRetrieverResourceEvent(
@ -822,6 +829,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
except GenerateTaskStoppedError: except GenerateTaskStoppedError:
# trigger node run failed event # trigger node run failed event
@ -838,6 +846,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
return return
except Exception as e: except Exception as e:

View File

@ -18,7 +18,11 @@ from models.workflow import WorkflowNodeExecutionStatus
class AnswerNode(BaseNode[AnswerNodeData]): class AnswerNode(BaseNode[AnswerNodeData]):
_node_data_cls = AnswerNodeData _node_data_cls = AnswerNodeData
_node_type: NodeType = NodeType.ANSWER _node_type = NodeType.ANSWER
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """

View File

@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor):
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"], from_variable_selector=[answer_node_id, "answer"],
node_version=event.node_version,
) )
else: else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk) route_chunk = cast(VarGenerateRouteChunk, route_chunk)
@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state, route_node_state=event.route_node_state,
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
) )
self.route_position[answer_node_id] += 1 self.route_position[answer_node_id] += 1

View File

@ -1,7 +1,7 @@
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
class BaseNode(Generic[GenericNodeData]): class BaseNode(Generic[GenericNodeData]):
_node_data_cls: type[GenericNodeData] _node_data_cls: type[GenericNodeData]
_node_type: NodeType _node_type: ClassVar[NodeType]
def __init__( def __init__(
self, self,
@ -101,9 +101,10 @@ class BaseNode(Generic[GenericNodeData]):
raise ValueError("Node ID is required when extracting variable selector to variable mapping.") raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_data = cls._node_data_cls(**config.get("data", {})) node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping( data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
) )
return data
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(
@ -139,6 +140,16 @@ class BaseNode(Generic[GenericNodeData]):
""" """
return self._node_type return self._node_type
@classmethod
@abstractmethod
def version(cls) -> str:
"""`node_version` returns the version of current node type."""
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
#
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
# in `api/core/workflow/nodes/__init__.py`.
pass
@property @property
def should_continue_on_error(self) -> bool: def should_continue_on_error(self) -> bool:
"""judge if should continue on error """judge if should continue on error

View File

@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]):
return code_provider.get_default_config() return code_provider.get_default_config()
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get code language # Get code language
code_language = self.node_data.code_language code_language = self.node_data.code_language

View File

@ -44,6 +44,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
_node_data_cls = DocumentExtractorNodeData _node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR _node_type = NodeType.DOCUMENT_EXTRACTOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self): def _run(self):
variable_selector = self.node_data.variable_selector variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector) variable = self.graph_runtime_state.variable_pool.get(variable_selector)

View File

@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]):
_node_data_cls = EndNodeData _node_data_cls = EndNodeData
_node_type = NodeType.END _node_type = NodeType.END
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
Run node Run node

View File

@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state, route_node_state=event.route_node_state,
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
) )
self.route_position[end_node_id] += 1 self.route_position[end_node_id] += 1

View File

@ -60,6 +60,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
}, },
} }
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
process_data = {} process_data = {}
try: try:

View File

@ -16,6 +16,10 @@ class IfElseNode(BaseNode[IfElseNodeData]):
_node_data_cls = IfElseNodeData _node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE _node_type = NodeType.IF_ELSE
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
Run node Run node

View File

@ -72,6 +72,10 @@ class IterationNode(BaseNode[IterationNodeData]):
}, },
} }
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
""" """
Run the node. Run the node.

View File

@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]):
_node_data_cls = IterationStartNodeData _node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START _node_type = NodeType.ITERATION_START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
Run the node. Run the node.

View File

@ -16,6 +16,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_data_cls = ListOperatorNodeData _node_data_cls = ListOperatorNodeData
_node_type = NodeType.LIST_OPERATOR _node_type = NodeType.LIST_OPERATOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self): def _run(self):
inputs: dict[str, list] = {} inputs: dict[str, list] = {}
process_data: dict[str, list] = {} process_data: dict[str, list] = {}

View File

@ -148,6 +148,10 @@ class LLMNode(BaseNode[LLMNodeData]):
) )
self._llm_file_saver = llm_file_saver self._llm_file_saver = llm_file_saver
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any]]: def process_structured_output(text: str) -> Optional[dict[str, Any]]:
"""Process structured output if enabled""" """Process structured output if enabled"""

View File

@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]):
_node_data_cls = LoopEndNodeData _node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END _node_type = NodeType.LOOP_END
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
Run the node. Run the node.

View File

@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]):
_node_data_cls = LoopNodeData _node_data_cls = LoopNodeData
_node_type = NodeType.LOOP _node_type = NodeType.LOOP
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node.""" """Run the node."""
# Get inputs # Get inputs

View File

@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]):
_node_data_cls = LoopStartNodeData _node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START _node_type = NodeType.LOOP_START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
Run the node. Run the node.

View File

@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var
LATEST_VERSION = "latest" LATEST_VERSION = "latest"
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
# Specifically, if you have introduced new node types, you should add them here.
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
NodeType.START: { NodeType.START: {
LATEST_VERSION: StartNode, LATEST_VERSION: StartNode,

View File

@ -1,3 +1,4 @@
from core.file.constants import add_dummy_output
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -10,6 +11,10 @@ class StartNode(BaseNode[StartNodeData]):
_node_data_cls = StartNodeData _node_data_cls = StartNodeData
_node_type = NodeType.START _node_type = NodeType.START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables system_inputs = self.graph_runtime_state.variable_pool.system_variables
@ -18,5 +23,9 @@ class StartNode(BaseNode[StartNodeData]):
# Set system variables as node outputs. # Set system variables as node outputs.
for var in system_inputs: for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
outputs = dict(node_inputs)
# Need special handling for `Start` node, as all other output variables
# are treated as systemd variables.
add_dummy_output(outputs)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)

View File

@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
} }
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get variables # Get variables
variables = {} variables = {}

View File

@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]):
_node_data_cls = ToolNodeData _node_data_cls = ToolNodeData
_node_type = NodeType.TOOL _node_type = NodeType.TOOL
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator: def _run(self) -> Generator:
""" """
Run the tool node Run the tool node

View File

@ -9,6 +9,10 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData _node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR _node_type = NodeType.VARIABLE_AGGREGATOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get variables # Get variables
outputs = {} outputs = {}

View File

@ -1,7 +1,11 @@
from collections.abc import Sequence
from typing import Any, TypedDict
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.variables import Variable from core.variables import Segment, SegmentType, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from extensions.ext_database import db from extensions.ext_database import db
from models import ConversationVariable from models import ConversationVariable
@ -17,3 +21,22 @@ def update_conversation_variable(conversation_id: str, variable: Variable):
raise VariableOperatorNodeError("conversation variable not found in the database") raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json() row.data = variable.model_dump_json()
session.commit() session.commit()
class VariableOutput(TypedDict):
name: str
selector: Sequence[str]
new_value: Any
type: SegmentType
def variable_to_output_mapping(selector: Sequence[str], seg: Segment) -> VariableOutput:
if len(selector) < MIN_SELECTORS_LENGTH:
raise Exception("selector too short")
node_id, var_name = selector[:2]
return {
"name": var_name,
"selector": selector[:2],
"new_value": seg.value,
"type": seg.value_type,
}

View File

@ -14,9 +14,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls = VariableAssignerData _node_data_cls = VariableAssignerData
_node_type = NodeType.VARIABLE_ASSIGNER _node_type = NodeType.VARIABLE_ASSIGNER
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector) original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable): if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found") raise VariableOperatorNodeError("assigned variable not found")
@ -44,7 +49,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable. # Over write the variable.
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable) self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline. # TODO: Move database operation to the pipeline.
# Update conversation variable. # Update conversation variable.
@ -58,6 +63,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
inputs={ inputs={
"value": income_value.to_object(), "value": income_value.to_object(),
}, },
outputs={
# NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`,
# we still set `output_variables` as a list to ensure the schema of output is
# compatible with `v2.VariableAssignerNode`.
"updated_variables": [
common_helpers.variable_to_output_mapping(assigned_variable_selector, updated_variable)
]
},
) )

View File

@ -29,6 +29,10 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData _node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER _node_type = NodeType.VARIABLE_ASSIGNER
@classmethod
def version(cls) -> str:
return "2"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump() inputs = self.node_data.model_dump()
process_data: dict[str, Any] = {} process_data: dict[str, Any] = {}
@ -137,6 +141,13 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs, inputs=inputs,
process_data=process_data, process_data=process_data,
outputs={
"updated_variables": [
common_helpers.variable_to_output_mapping(selector, seg)
for selector in updated_variable_selectors
if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None
],
},
) )
def _handle_item( def _handle_item(