From 4f64a5d36deaa10578c8e32a57e4d1ada60f2322 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 20 Aug 2024 15:40:19 +0800 Subject: [PATCH] refactor(api/core/workflow/nodes/variable_assigner): Split into multi files. (#7434) --- .../nodes/variable_assigner/__init__.py | 115 ++---------------- .../workflow/nodes/variable_assigner/exc.py | 2 + .../workflow/nodes/variable_assigner/node.py | 92 ++++++++++++++ .../nodes/variable_assigner/node_data.py | 19 +++ .../workflow/nodes/test_variable_assigner.py | 4 +- 5 files changed, 122 insertions(+), 110 deletions(-) create mode 100644 api/core/workflow/nodes/variable_assigner/exc.py create mode 100644 api/core/workflow/nodes/variable_assigner/node.py create mode 100644 api/core/workflow/nodes/variable_assigner/node_data.py diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py index 552cc367f2..d791d51523 100644 --- a/api/core/workflow/nodes/variable_assigner/__init__.py +++ b/api/core/workflow/nodes/variable_assigner/__init__.py @@ -1,109 +1,8 @@ -from collections.abc import Sequence -from enum import Enum -from typing import Optional, cast +from .node import VariableAssignerNode +from .node_data import VariableAssignerData, WriteMode -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.app.segments import SegmentType, Variable, factory -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode -from extensions.ext_database import db -from models import ConversationVariable, WorkflowNodeExecutionStatus - - -class VariableAssignerNodeError(Exception): - pass - - -class WriteMode(str, Enum): - OVER_WRITE = 'over-write' - APPEND = 'append' - CLEAR = 'clear' - - -class VariableAssignerData(BaseNodeData): - title: str = 'Variable Assigner' - desc: Optional[str] = 'Assign a value to a variable' - assigned_variable_selector: Sequence[str] - write_mode: WriteMode - input_variable_selector: Sequence[str] - - -class VariableAssignerNode(BaseNode): - _node_data_cls: type[BaseNodeData] = VariableAssignerData - _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER - - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - data = cast(VariableAssignerData, self.node_data) - - # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = variable_pool.get(data.assigned_variable_selector) - if not isinstance(original_variable, Variable): - raise VariableAssignerNodeError('assigned variable not found') - - match data.write_mode: - case WriteMode.OVER_WRITE: - income_value = variable_pool.get(data.input_variable_selector) - if not income_value: - raise VariableAssignerNodeError('input value not found') - updated_variable = original_variable.model_copy(update={'value': income_value.value}) - - case WriteMode.APPEND: - income_value = variable_pool.get(data.input_variable_selector) - if not income_value: - raise VariableAssignerNodeError('input value not found') - updated_value = original_variable.value + [income_value.value] - updated_variable = original_variable.model_copy(update={'value': updated_value}) - - case WriteMode.CLEAR: - income_value = get_zero_value(original_variable.value_type) - updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) - - case _: - raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') - - # Over write the variable. - variable_pool.add(data.assigned_variable_selector, updated_variable) - - # Update conversation variable. - # TODO: Find a better way to use the database. - conversation_id = variable_pool.get(['sys', 'conversation_id']) - if not conversation_id: - raise VariableAssignerNodeError('conversation_id not found') - update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={ - 'value': income_value.to_object(), - }, - ) - - -def update_conversation_variable(conversation_id: str, variable: Variable): - stmt = select(ConversationVariable).where( - ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id - ) - with Session(db.engine) as session: - row = session.scalar(stmt) - if not row: - raise VariableAssignerNodeError('conversation variable not found in the database') - row.data = variable.model_dump_json() - session.commit() - - -def get_zero_value(t: SegmentType): - match t: - case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: - return factory.build_segment([]) - case SegmentType.OBJECT: - return factory.build_segment({}) - case SegmentType.STRING: - return factory.build_segment('') - case SegmentType.NUMBER: - return factory.build_segment(0) - case _: - raise VariableAssignerNodeError(f'unsupported variable type: {t}') +__all__ = [ + 'VariableAssignerNode', + 'VariableAssignerData', + 'WriteMode', +] diff --git a/api/core/workflow/nodes/variable_assigner/exc.py b/api/core/workflow/nodes/variable_assigner/exc.py new file mode 100644 index 0000000000..914be22256 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/exc.py @@ -0,0 +1,2 @@ +class VariableAssignerNodeError(Exception): + pass diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py new file mode 100644 index 0000000000..8c2adcabb9 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -0,0 +1,92 @@ +from typing import cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.segments import SegmentType, Variable, factory +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from extensions.ext_database import db +from models import ConversationVariable, WorkflowNodeExecutionStatus + +from .exc import VariableAssignerNodeError +from .node_data import VariableAssignerData, WriteMode + + +class VariableAssignerNode(BaseNode): + _node_data_cls: type[BaseNodeData] = VariableAssignerData + _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + data = cast(VariableAssignerData, self.node_data) + + # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject + original_variable = variable_pool.get(data.assigned_variable_selector) + if not isinstance(original_variable, Variable): + raise VariableAssignerNodeError('assigned variable not found') + + match data.write_mode: + case WriteMode.OVER_WRITE: + income_value = variable_pool.get(data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError('input value not found') + updated_variable = original_variable.model_copy(update={'value': income_value.value}) + + case WriteMode.APPEND: + income_value = variable_pool.get(data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError('input value not found') + updated_value = original_variable.value + [income_value.value] + updated_variable = original_variable.model_copy(update={'value': updated_value}) + + case WriteMode.CLEAR: + income_value = get_zero_value(original_variable.value_type) + updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) + + case _: + raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') + + # Over write the variable. + variable_pool.add(data.assigned_variable_selector, updated_variable) + + # TODO: Move database operation to the pipeline. + # Update conversation variable. + conversation_id = variable_pool.get(['sys', 'conversation_id']) + if not conversation_id: + raise VariableAssignerNodeError('conversation_id not found') + update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + 'value': income_value.to_object(), + }, + ) + + +def update_conversation_variable(conversation_id: str, variable: Variable): + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with Session(db.engine) as session: + row = session.scalar(stmt) + if not row: + raise VariableAssignerNodeError('conversation variable not found in the database') + row.data = variable.model_dump_json() + session.commit() + + +def get_zero_value(t: SegmentType): + match t: + case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: + return factory.build_segment([]) + case SegmentType.OBJECT: + return factory.build_segment({}) + case SegmentType.STRING: + return factory.build_segment('') + case SegmentType.NUMBER: + return factory.build_segment(0) + case _: + raise VariableAssignerNodeError(f'unsupported variable type: {t}') diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/node_data.py new file mode 100644 index 0000000000..b3652b6802 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/node_data.py @@ -0,0 +1,19 @@ +from collections.abc import Sequence +from enum import Enum +from typing import Optional + +from core.workflow.entities.base_node_data_entities import BaseNodeData + + +class WriteMode(str, Enum): + OVER_WRITE = 'over-write' + APPEND = 'append' + CLEAR = 'clear' + + +class VariableAssignerData(BaseNodeData): + title: str = 'Variable Assigner' + desc: Optional[str] = 'Assign a value to a variable' + assigned_variable_selector: Sequence[str] + write_mode: WriteMode + input_variable_selector: Sequence[str] diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py index 0b37d06fc0..78b3cf1415 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -52,7 +52,7 @@ def test_overwrite_string_variable(): input_variable, ) - with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run: + with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run: node.run(variable_pool) mock_run.assert_called_once() @@ -103,7 +103,7 @@ def test_append_variable_to_array(): input_variable, ) - with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run: + with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run: node.run(variable_pool) mock_run.assert_called_once()