mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 10:45:54 +08:00
feat(api): add a version
class method to BaseNode and subclasses
This ensures that we can get the version of a node while executing. Add `node_version` to `BaseNodeEvent` to ensure that all node-related events include node version information.
This commit is contained in:
parent
392ad17497
commit
25b4a96aed
@ -65,6 +65,8 @@ class BaseNodeEvent(GraphEngineEvent):
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
# The version of the node, or "1" if not specified.
|
||||
node_version: str = "1"
|
||||
|
||||
|
||||
class NodeRunStartedEvent(BaseNodeEvent):
|
||||
|
@ -313,6 +313,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
raise e
|
||||
|
||||
@ -630,6 +631,7 @@ class GraphEngine:
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
agent_strategy=agent_strategy,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
@ -688,6 +690,7 @@ class GraphEngine:
|
||||
error=run_result.error or "Unknown error",
|
||||
retry_index=retries,
|
||||
start_at=retry_start_at,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
time.sleep(retry_interval)
|
||||
break
|
||||
@ -723,6 +726,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
should_continue_retry = False
|
||||
else:
|
||||
@ -737,6 +741,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
should_continue_retry = False
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
@ -791,6 +796,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
should_continue_retry = False
|
||||
|
||||
@ -808,6 +814,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
elif isinstance(item, RunRetrieverResourceEvent):
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
@ -822,6 +829,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
except GenerateTaskStoppedError:
|
||||
# trigger node run failed event
|
||||
@ -838,6 +846,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
|
@ -18,7 +18,11 @@ from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
class AnswerNode(BaseNode[AnswerNodeData]):
|
||||
_node_data_cls = AnswerNodeData
|
||||
_node_type: NodeType = NodeType.ANSWER
|
||||
_node_type = NodeType.ANSWER
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
|
@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
from_variable_selector=[answer_node_id, "answer"],
|
||||
node_version=event.node_version,
|
||||
)
|
||||
else:
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
node_version=event.node_version,
|
||||
)
|
||||
|
||||
self.route_position[answer_node_id] += 1
|
||||
|
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
|
||||
@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
|
||||
|
||||
class BaseNode(Generic[GenericNodeData]):
|
||||
_node_data_cls: type[GenericNodeData]
|
||||
_node_type: NodeType
|
||||
_node_type: ClassVar[NodeType]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -101,9 +101,10 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
return cls._extract_variable_selector_to_variable_mapping(
|
||||
data = cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
|
||||
)
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
@ -139,6 +140,16 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
"""
|
||||
return self._node_type
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def version(cls) -> str:
|
||||
"""`node_version` returns the version of current node type."""
|
||||
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
|
||||
#
|
||||
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
|
||||
# in `api/core/workflow/nodes/__init__.py`.
|
||||
pass
|
||||
|
||||
@property
|
||||
def should_continue_on_error(self) -> bool:
|
||||
"""judge if should continue on error
|
||||
|
@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
|
||||
return code_provider.get_default_config()
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get code language
|
||||
code_language = self.node_data.code_language
|
||||
|
@ -44,6 +44,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
|
||||
_node_data_cls = DocumentExtractorNodeData
|
||||
_node_type = NodeType.DOCUMENT_EXTRACTOR
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
variable_selector = self.node_data.variable_selector
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
|
||||
|
@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]):
|
||||
_node_data_cls = EndNodeData
|
||||
_node_type = NodeType.END
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
|
@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
node_version=event.node_version,
|
||||
)
|
||||
|
||||
self.route_position[end_node_id] += 1
|
||||
|
@ -60,6 +60,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
process_data = {}
|
||||
try:
|
||||
|
@ -16,6 +16,10 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
||||
_node_data_cls = IfElseNodeData
|
||||
_node_type = NodeType.IF_ELSE
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
|
@ -72,6 +72,10 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
Run the node.
|
||||
|
@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]):
|
||||
_node_data_cls = IterationStartNodeData
|
||||
_node_type = NodeType.ITERATION_START
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
|
@ -16,6 +16,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
||||
_node_data_cls = ListOperatorNodeData
|
||||
_node_type = NodeType.LIST_OPERATOR
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
inputs: dict[str, list] = {}
|
||||
process_data: dict[str, list] = {}
|
||||
|
@ -148,6 +148,10 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
def process_structured_output(text: str) -> Optional[dict[str, Any]]:
|
||||
"""Process structured output if enabled"""
|
||||
|
@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]):
|
||||
_node_data_cls = LoopEndNodeData
|
||||
_node_type = NodeType.LOOP_END
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
|
@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
_node_data_cls = LoopNodeData
|
||||
_node_type = NodeType.LOOP
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""Run the node."""
|
||||
# Get inputs
|
||||
|
@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]):
|
||||
_node_data_cls = LoopStartNodeData
|
||||
_node_type = NodeType.LOOP_START
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
|
@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var
|
||||
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
|
||||
# Specifically, if you have introduced new node types, you should add them here.
|
||||
#
|
||||
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
|
||||
# hook. Try to avoid duplication of node information.
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
NodeType.START: {
|
||||
LATEST_VERSION: StartNode,
|
||||
|
@ -1,3 +1,4 @@
|
||||
from core.file.constants import add_dummy_output
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
@ -10,6 +11,10 @@ class StartNode(BaseNode[StartNodeData]):
|
||||
_node_data_cls = StartNodeData
|
||||
_node_type = NodeType.START
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables
|
||||
@ -18,5 +23,9 @@ class StartNode(BaseNode[StartNodeData]):
|
||||
# Set system variables as node outputs.
|
||||
for var in system_inputs:
|
||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
|
||||
outputs = dict(node_inputs)
|
||||
# Need special handling for `Start` node, as all other output variables
|
||||
# are treated as system variables.
|
||||
add_dummy_output(outputs)
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)
|
||||
|
@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
|
||||
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get variables
|
||||
variables = {}
|
||||
|
@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
_node_data_cls = ToolNodeData
|
||||
_node_type = NodeType.TOOL
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the tool node
|
||||
|
@ -9,6 +9,10 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
|
||||
_node_data_cls = VariableAssignerNodeData
|
||||
_node_type = NodeType.VARIABLE_AGGREGATOR
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get variables
|
||||
outputs = {}
|
||||
|
@ -1,7 +1,11 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables import Segment, SegmentType, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable
|
||||
@ -17,3 +21,22 @@ def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
|
||||
class VariableOutput(TypedDict):
    """Schema of one updated-variable entry placed in a node's `updated_variables` output."""

    # Variable name (second element of the selector).
    name: str
    # Two-element selector identifying the variable: (node_id, variable name).
    selector: Sequence[str]
    # The variable's value after the assignment.
    new_value: Any
    # Segment type of the new value — presumably mirrors `Segment.value_type`; confirm against callers.
    type: SegmentType
|
||||
|
||||
|
||||
def variable_to_output_mapping(selector: Sequence[str], seg: Segment) -> VariableOutput:
    """Build a `VariableOutput` entry for an updated variable.

    Args:
        selector: Variable selector; the first two elements are
            (node_id, variable name) and identify the variable.
        seg: Segment holding the variable's new value.

    Returns:
        A `VariableOutput` mapping describing the updated variable.

    Raises:
        VariableOperatorNodeError: If the selector has fewer than
            `MIN_SELECTORS_LENGTH` elements.
    """
    if len(selector) < MIN_SELECTORS_LENGTH:
        # Raise the module's domain error for consistency with
        # `update_conversation_variable` above, instead of a bare `Exception`.
        raise VariableOperatorNodeError("selector too short")
    # Only the variable name is needed here; the node id part of the
    # selector is carried through unchanged in `selector[:2]`.
    var_name = selector[1]
    return {
        "name": var_name,
        "selector": selector[:2],
        "new_value": seg.value,
        "type": seg.value_type,
    }
|
||||
|
@ -14,9 +14,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
_node_data_cls = VariableAssignerData
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
assigned_variable_selector = self.node_data.assigned_variable_selector
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
|
||||
@ -44,7 +49,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
|
||||
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
||||
|
||||
# TODO: Move database operation to the pipeline.
|
||||
# Update conversation variable.
|
||||
@ -58,6 +63,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
inputs={
|
||||
"value": income_value.to_object(),
|
||||
},
|
||||
outputs={
|
||||
# NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`,
|
||||
# we still set `output_variables` as a list to ensure the schema of output is
|
||||
# compatible with `v2.VariableAssignerNode`.
|
||||
"updated_variables": [
|
||||
common_helpers.variable_to_output_mapping(assigned_variable_selector, updated_variable)
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
@ -29,6 +29,10 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
_node_data_cls = VariableAssignerNodeData
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "2"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
inputs = self.node_data.model_dump()
|
||||
process_data: dict[str, Any] = {}
|
||||
@ -137,6 +141,13 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={
|
||||
"updated_variables": [
|
||||
common_helpers.variable_to_output_mapping(selector, seg)
|
||||
for selector in updated_variable_selectors
|
||||
if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def _handle_item(
|
||||
|
Loading…
x
Reference in New Issue
Block a user