mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 20:29:02 +08:00
refactor(api/core/workflow/nodes/variable_assigner): Split into multi files. (#7434)
This commit is contained in:
parent
0d4753785f
commit
4f64a5d36d
@ -1,109 +1,8 @@
|
|||||||
from collections.abc import Sequence
|
from .node import VariableAssignerNode
|
||||||
from enum import Enum
|
from .node_data import VariableAssignerData, WriteMode
|
||||||
from typing import Optional, cast
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
__all__ = [
|
||||||
from sqlalchemy.orm import Session
|
'VariableAssignerNode',
|
||||||
|
'VariableAssignerData',
|
||||||
from core.app.segments import SegmentType, Variable, factory
|
'WriteMode',
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
]
|
||||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models import ConversationVariable, WorkflowNodeExecutionStatus
|
|
||||||
|
|
||||||
|
|
||||||
class VariableAssignerNodeError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class WriteMode(str, Enum):
|
|
||||||
OVER_WRITE = 'over-write'
|
|
||||||
APPEND = 'append'
|
|
||||||
CLEAR = 'clear'
|
|
||||||
|
|
||||||
|
|
||||||
class VariableAssignerData(BaseNodeData):
|
|
||||||
title: str = 'Variable Assigner'
|
|
||||||
desc: Optional[str] = 'Assign a value to a variable'
|
|
||||||
assigned_variable_selector: Sequence[str]
|
|
||||||
write_mode: WriteMode
|
|
||||||
input_variable_selector: Sequence[str]
|
|
||||||
|
|
||||||
|
|
||||||
class VariableAssignerNode(BaseNode):
|
|
||||||
_node_data_cls: type[BaseNodeData] = VariableAssignerData
|
|
||||||
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
|
|
||||||
|
|
||||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
|
||||||
data = cast(VariableAssignerData, self.node_data)
|
|
||||||
|
|
||||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
|
||||||
original_variable = variable_pool.get(data.assigned_variable_selector)
|
|
||||||
if not isinstance(original_variable, Variable):
|
|
||||||
raise VariableAssignerNodeError('assigned variable not found')
|
|
||||||
|
|
||||||
match data.write_mode:
|
|
||||||
case WriteMode.OVER_WRITE:
|
|
||||||
income_value = variable_pool.get(data.input_variable_selector)
|
|
||||||
if not income_value:
|
|
||||||
raise VariableAssignerNodeError('input value not found')
|
|
||||||
updated_variable = original_variable.model_copy(update={'value': income_value.value})
|
|
||||||
|
|
||||||
case WriteMode.APPEND:
|
|
||||||
income_value = variable_pool.get(data.input_variable_selector)
|
|
||||||
if not income_value:
|
|
||||||
raise VariableAssignerNodeError('input value not found')
|
|
||||||
updated_value = original_variable.value + [income_value.value]
|
|
||||||
updated_variable = original_variable.model_copy(update={'value': updated_value})
|
|
||||||
|
|
||||||
case WriteMode.CLEAR:
|
|
||||||
income_value = get_zero_value(original_variable.value_type)
|
|
||||||
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
|
|
||||||
|
|
||||||
case _:
|
|
||||||
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
|
|
||||||
|
|
||||||
# Over write the variable.
|
|
||||||
variable_pool.add(data.assigned_variable_selector, updated_variable)
|
|
||||||
|
|
||||||
# Update conversation variable.
|
|
||||||
# TODO: Find a better way to use the database.
|
|
||||||
conversation_id = variable_pool.get(['sys', 'conversation_id'])
|
|
||||||
if not conversation_id:
|
|
||||||
raise VariableAssignerNodeError('conversation_id not found')
|
|
||||||
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
|
||||||
|
|
||||||
return NodeRunResult(
|
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
|
||||||
inputs={
|
|
||||||
'value': income_value.to_object(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
|
||||||
stmt = select(ConversationVariable).where(
|
|
||||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
|
||||||
)
|
|
||||||
with Session(db.engine) as session:
|
|
||||||
row = session.scalar(stmt)
|
|
||||||
if not row:
|
|
||||||
raise VariableAssignerNodeError('conversation variable not found in the database')
|
|
||||||
row.data = variable.model_dump_json()
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def get_zero_value(t: SegmentType):
|
|
||||||
match t:
|
|
||||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
|
||||||
return factory.build_segment([])
|
|
||||||
case SegmentType.OBJECT:
|
|
||||||
return factory.build_segment({})
|
|
||||||
case SegmentType.STRING:
|
|
||||||
return factory.build_segment('')
|
|
||||||
case SegmentType.NUMBER:
|
|
||||||
return factory.build_segment(0)
|
|
||||||
case _:
|
|
||||||
raise VariableAssignerNodeError(f'unsupported variable type: {t}')
|
|
||||||
|
2
api/core/workflow/nodes/variable_assigner/exc.py
Normal file
2
api/core/workflow/nodes/variable_assigner/exc.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
class VariableAssignerNodeError(Exception):
|
||||||
|
pass
|
92
api/core/workflow/nodes/variable_assigner/node.py
Normal file
92
api/core/workflow/nodes/variable_assigner/node.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.app.segments import SegmentType, Variable, factory
|
||||||
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models import ConversationVariable, WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
from .exc import VariableAssignerNodeError
|
||||||
|
from .node_data import VariableAssignerData, WriteMode
|
||||||
|
|
||||||
|
|
||||||
|
class VariableAssignerNode(BaseNode):
|
||||||
|
_node_data_cls: type[BaseNodeData] = VariableAssignerData
|
||||||
|
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
|
||||||
|
|
||||||
|
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||||
|
data = cast(VariableAssignerData, self.node_data)
|
||||||
|
|
||||||
|
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||||
|
original_variable = variable_pool.get(data.assigned_variable_selector)
|
||||||
|
if not isinstance(original_variable, Variable):
|
||||||
|
raise VariableAssignerNodeError('assigned variable not found')
|
||||||
|
|
||||||
|
match data.write_mode:
|
||||||
|
case WriteMode.OVER_WRITE:
|
||||||
|
income_value = variable_pool.get(data.input_variable_selector)
|
||||||
|
if not income_value:
|
||||||
|
raise VariableAssignerNodeError('input value not found')
|
||||||
|
updated_variable = original_variable.model_copy(update={'value': income_value.value})
|
||||||
|
|
||||||
|
case WriteMode.APPEND:
|
||||||
|
income_value = variable_pool.get(data.input_variable_selector)
|
||||||
|
if not income_value:
|
||||||
|
raise VariableAssignerNodeError('input value not found')
|
||||||
|
updated_value = original_variable.value + [income_value.value]
|
||||||
|
updated_variable = original_variable.model_copy(update={'value': updated_value})
|
||||||
|
|
||||||
|
case WriteMode.CLEAR:
|
||||||
|
income_value = get_zero_value(original_variable.value_type)
|
||||||
|
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
|
||||||
|
|
||||||
|
# Over write the variable.
|
||||||
|
variable_pool.add(data.assigned_variable_selector, updated_variable)
|
||||||
|
|
||||||
|
# TODO: Move database operation to the pipeline.
|
||||||
|
# Update conversation variable.
|
||||||
|
conversation_id = variable_pool.get(['sys', 'conversation_id'])
|
||||||
|
if not conversation_id:
|
||||||
|
raise VariableAssignerNodeError('conversation_id not found')
|
||||||
|
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||||
|
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
inputs={
|
||||||
|
'value': income_value.to_object(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||||
|
stmt = select(ConversationVariable).where(
|
||||||
|
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||||
|
)
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
row = session.scalar(stmt)
|
||||||
|
if not row:
|
||||||
|
raise VariableAssignerNodeError('conversation variable not found in the database')
|
||||||
|
row.data = variable.model_dump_json()
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def get_zero_value(t: SegmentType):
|
||||||
|
match t:
|
||||||
|
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||||
|
return factory.build_segment([])
|
||||||
|
case SegmentType.OBJECT:
|
||||||
|
return factory.build_segment({})
|
||||||
|
case SegmentType.STRING:
|
||||||
|
return factory.build_segment('')
|
||||||
|
case SegmentType.NUMBER:
|
||||||
|
return factory.build_segment(0)
|
||||||
|
case _:
|
||||||
|
raise VariableAssignerNodeError(f'unsupported variable type: {t}')
|
19
api/core/workflow/nodes/variable_assigner/node_data.py
Normal file
19
api/core/workflow/nodes/variable_assigner/node_data.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
|
|
||||||
|
|
||||||
|
class WriteMode(str, Enum):
|
||||||
|
OVER_WRITE = 'over-write'
|
||||||
|
APPEND = 'append'
|
||||||
|
CLEAR = 'clear'
|
||||||
|
|
||||||
|
|
||||||
|
class VariableAssignerData(BaseNodeData):
|
||||||
|
title: str = 'Variable Assigner'
|
||||||
|
desc: Optional[str] = 'Assign a value to a variable'
|
||||||
|
assigned_variable_selector: Sequence[str]
|
||||||
|
write_mode: WriteMode
|
||||||
|
input_variable_selector: Sequence[str]
|
@ -52,7 +52,7 @@ def test_overwrite_string_variable():
|
|||||||
input_variable,
|
input_variable,
|
||||||
)
|
)
|
||||||
|
|
||||||
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
|
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
|
||||||
node.run(variable_pool)
|
node.run(variable_pool)
|
||||||
mock_run.assert_called_once()
|
mock_run.assert_called_once()
|
||||||
|
|
||||||
@ -103,7 +103,7 @@ def test_append_variable_to_array():
|
|||||||
input_variable,
|
input_variable,
|
||||||
)
|
)
|
||||||
|
|
||||||
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
|
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
|
||||||
node.run(variable_pool)
|
node.run(variable_pool)
|
||||||
mock_run.assert_called_once()
|
mock_run.assert_called_once()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user