feat(api): add a version class method to BaseNode and subclasses

This ensures that we can get the version of node while executing.

Add `node_version` to `BaseNodeEvent` to ensure that all node
related events includes node version information.
This commit is contained in:
QuantumGhost 2025-05-22 17:28:57 +08:00
parent 392ad17497
commit 25b4a96aed
26 changed files with 158 additions and 8 deletions

View File

@ -65,6 +65,8 @@ class BaseNodeEvent(GraphEngineEvent):
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
# The version of the node, or "1" if not specified.
node_version: str = "1"
class NodeRunStartedEvent(BaseNodeEvent):

View File

@ -313,6 +313,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
raise e
@ -630,6 +631,7 @@ class GraphEngine:
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
agent_strategy=agent_strategy,
node_version=node_instance.version(),
)
db.session.close()
@ -688,6 +690,7 @@ class GraphEngine:
error=run_result.error or "Unknown error",
retry_index=retries,
start_at=retry_start_at,
node_version=node_instance.version(),
)
time.sleep(retry_interval)
break
@ -723,6 +726,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
should_continue_retry = False
else:
@ -737,6 +741,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
@ -791,6 +796,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
should_continue_retry = False
@ -808,6 +814,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
@ -822,6 +829,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
except GenerateTaskStoppedError:
# trigger node run failed event
@ -838,6 +846,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
return
except Exception as e:

View File

@ -18,7 +18,11 @@ from models.workflow import WorkflowNodeExecutionStatus
class AnswerNode(BaseNode[AnswerNodeData]):
_node_data_cls = AnswerNodeData
_node_type: NodeType = NodeType.ANSWER
_node_type = NodeType.ANSWER
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""

View File

@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor):
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"],
node_version=event.node_version,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
)
self.route_position[answer_node_id] += 1

View File

@ -1,7 +1,7 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
class BaseNode(Generic[GenericNodeData]):
_node_data_cls: type[GenericNodeData]
_node_type: NodeType
_node_type: ClassVar[NodeType]
def __init__(
self,
@ -101,9 +101,10 @@ class BaseNode(Generic[GenericNodeData]):
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping(
data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
)
return data
@classmethod
def _extract_variable_selector_to_variable_mapping(
@ -139,6 +140,16 @@ class BaseNode(Generic[GenericNodeData]):
"""
return self._node_type
@classmethod
@abstractmethod
def version(cls) -> str:
"""`node_version` returns the version of current node type."""
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
#
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
# in `api/core/workflow/nodes/__init__.py`.
pass
@property
def should_continue_on_error(self) -> bool:
"""judge if should continue on error

View File

@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]):
return code_provider.get_default_config()
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
# Get code language
code_language = self.node_data.code_language

View File

@ -44,6 +44,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
_node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)

View File

@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]):
_node_data_cls = EndNodeData
_node_type = NodeType.END
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run node

View File

@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
)
self.route_position[end_node_id] += 1

View File

@ -60,6 +60,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
process_data = {}
try:

View File

@ -16,6 +16,10 @@ class IfElseNode(BaseNode[IfElseNodeData]):
_node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run node

View File

@ -72,6 +72,10 @@ class IterationNode(BaseNode[IterationNodeData]):
},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.

View File

@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]):
_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run the node.

View File

@ -16,6 +16,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_data_cls = ListOperatorNodeData
_node_type = NodeType.LIST_OPERATOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
inputs: dict[str, list] = {}
process_data: dict[str, list] = {}

View File

@ -148,6 +148,10 @@ class LLMNode(BaseNode[LLMNodeData]):
)
self._llm_file_saver = llm_file_saver
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any]]:
"""Process structured output if enabled"""

View File

@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]):
_node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run the node.

View File

@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]):
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node."""
# Get inputs

View File

@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]):
_node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run the node.

View File

@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var
LATEST_VERSION = "latest"
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
# Specifically, if you have introduced new node types, you should add them here.
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,

View File

@ -1,3 +1,4 @@
from core.file.constants import add_dummy_output
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
@ -10,6 +11,10 @@ class StartNode(BaseNode[StartNodeData]):
_node_data_cls = StartNodeData
_node_type = NodeType.START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables
@ -18,5 +23,9 @@ class StartNode(BaseNode[StartNodeData]):
# Set system variables as node outputs.
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
outputs = dict(node_inputs)
# Need special handling for `Start` node, as all other output variables
# are treated as systemd variables.
add_dummy_output(outputs)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)

View File

@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
# Get variables
variables = {}

View File

@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]):
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
"""
Run the tool node

View File

@ -9,6 +9,10 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
# Get variables
outputs = {}

View File

@ -1,7 +1,11 @@
from collections.abc import Sequence
from typing import Any, TypedDict
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables import Variable
from core.variables import Segment, SegmentType, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from extensions.ext_database import db
from models import ConversationVariable
@ -17,3 +21,22 @@ def update_conversation_variable(conversation_id: str, variable: Variable):
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
class VariableOutput(TypedDict):
name: str
selector: Sequence[str]
new_value: Any
type: SegmentType
def variable_to_output_mapping(selector: Sequence[str], seg: Segment) -> VariableOutput:
if len(selector) < MIN_SELECTORS_LENGTH:
raise Exception("selector too short")
node_id, var_name = selector[:2]
return {
"name": var_name,
"selector": selector[:2],
"new_value": seg.value,
"type": seg.value_type,
}

View File

@ -14,9 +14,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls = VariableAssignerData
_node_type = NodeType.VARIABLE_ASSIGNER
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found")
@ -44,7 +49,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
@ -58,6 +63,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
inputs={
"value": income_value.to_object(),
},
outputs={
# NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`,
# we still set `output_variables` as a list to ensure the schema of output is
# compatible with `v2.VariableAssignerNode`.
"updated_variables": [
common_helpers.variable_to_output_mapping(assigned_variable_selector, updated_variable)
]
},
)

View File

@ -29,6 +29,10 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER
@classmethod
def version(cls) -> str:
return "2"
def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
process_data: dict[str, Any] = {}
@ -137,6 +141,13 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={
"updated_variables": [
common_helpers.variable_to_output_mapping(selector, seg)
for selector in updated_variable_selectors
if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None
],
},
)
def _handle_item(