From 4f5f27cf2ba6322b4a9fbd7a98fa791f3083c8bf Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 20 Aug 2024 17:52:06 +0800 Subject: [PATCH] refactor(api/core/workflow/enums.py): Rename SystemVariable to SystemVariableKey. (#7445) --- .../app/apps/advanced_chat/app_generator.py | 32 +++++++------- .../advanced_chat/generate_task_pipeline.py | 12 +++--- api/core/app/apps/workflow/app_runner.py | 6 +-- .../apps/workflow/generate_task_pipeline.py | 8 ++-- .../workflow_cycle_state_manager.py | 4 +- api/core/workflow/entities/variable_pool.py | 18 ++++---- api/core/workflow/enums.py | 24 +++-------- api/core/workflow/nodes/llm/llm_node.py | 12 +++--- api/core/workflow/nodes/start/start_node.py | 14 +++---- api/core/workflow/nodes/tool/tool_node.py | 4 +- .../workflow/nodes/test_llm.py | 18 ++++---- .../nodes/test_parameter_extractor.py | 42 +++++++++---------- .../core/app/segments/test_segment.py | 6 +-- .../core/workflow/nodes/test_answer.py | 6 +-- .../core/workflow/nodes/test_if_else.py | 10 ++--- .../workflow/nodes/test_variable_assigner.py | 8 ++-- 16 files changed, 106 insertions(+), 118 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 351eb05d8a..5a1e5973cd 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -29,7 +29,7 @@ from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message @@ -46,7 +46,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): args: dict, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + ): """ Generate App response. @@ -73,8 +73,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + conversation_id = args.get('conversation_id') + if conversation_id: + conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) # parse files files = args['files'] if args.get('files') else [] @@ -133,8 +134,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): node_id: str, user: Account, args: dict, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + stream: bool = True): """ Generate App response. @@ -157,8 +157,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + conversation_id = args.get('conversation_id') + if conversation_id: + conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) # convert to app config app_config = AdvancedChatAppConfigManager.get_app_config( @@ -200,8 +201,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from: InvokeFrom, application_generate_entity: AdvancedChatAppGenerateEntity, conversation: Conversation | None = None, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + stream: bool = True): is_first_conversation = False if not conversation: is_first_conversation = True @@ -270,11 +270,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create a variable pool. system_inputs = { - SystemVariable.QUERY: query, - SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: conversation_id, - SystemVariable.USER_ID: user_id, - SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count, + SystemVariableKey.QUERY: query, + SystemVariableKey.FILES: files, + SystemVariableKey.CONVERSATION_ID: conversation_id, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, } variable_pool = VariablePool( system_variables=system_inputs, @@ -362,7 +362,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG", "false").lower() == 'true': logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index ac51a4e840..2b3596ded2 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -49,7 +49,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.node_entities import NodeType -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk from events.message_event import message_was_created @@ -74,7 +74,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _workflow: Workflow _user: Union[Account, EndUser] # Deprecated - _workflow_system_variables: dict[SystemVariable, Any] + _workflow_system_variables: dict[SystemVariableKey, Any] _iteration_nested_relations: dict[str, list[str]] def __init__( @@ -108,10 +108,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._message = message # Deprecated self._workflow_system_variables = { - SystemVariable.QUERY: message.query, - SystemVariable.FILES: application_generate_entity.files, - SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id, + SystemVariableKey.QUERY: message.query, + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.CONVERSATION_ID: conversation.id, + SystemVariableKey.USER_ID: user_id, } self._task_state = AdvancedChatTaskState( diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 994919391e..e388d0184b 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db @@ -67,8 +67,8 @@ class WorkflowAppRunner: # Create a variable pool. system_inputs = { - SystemVariable.FILES: files, - SystemVariable.USER_ID: user_id, + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, } variable_pool = VariablePool( system_variables=system_inputs, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 5022eb0438..de8542d7b9 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -43,7 +43,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.node_entities import NodeType -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.end.end_node import EndNode from extensions.ext_database import db from models.account import Account @@ -67,7 +67,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa _user: Union[Account, EndUser] _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity - _workflow_system_variables: dict[SystemVariable, Any] + _workflow_system_variables: dict[SystemVariableKey, Any] _iteration_nested_relations: dict[str, list[str]] def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, @@ -92,8 +92,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa self._workflow = workflow self._workflow_system_variables = { - SystemVariable.FILES: application_generate_entity.files, - SystemVariable.USER_ID: user_id + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.USER_ID: user_id } self._task_state = WorkflowTaskState( diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py index 8baa8ba09e..bd98c82720 100644 --- a/api/core/app/task_pipeline/workflow_cycle_state_manager.py +++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py @@ -2,7 +2,7 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from models.account import Account from models.model import EndUser from models.workflow import Workflow @@ -13,4 +13,4 @@ class WorkflowCycleStateManager: _workflow: Workflow _user: Union[Account, EndUser] _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariable, Any] + _workflow_system_variables: dict[SystemVariableKey, Any] diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 9fe3356faa..8120b2ac78 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -6,20 +6,20 @@ from typing_extensions import deprecated from core.app.segments import Segment, Variable, factory from core.file.file_obj import FileVar -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey VariableValue = Union[str, int, float, dict, list, FileVar] -SYSTEM_VARIABLE_NODE_ID = 'sys' -ENVIRONMENT_VARIABLE_NODE_ID = 'env' -CONVERSATION_VARIABLE_NODE_ID = 'conversation' +SYSTEM_VARIABLE_NODE_ID = "sys" +ENVIRONMENT_VARIABLE_NODE_ID = "env" +CONVERSATION_VARIABLE_NODE_ID = "conversation" class VariablePool: def __init__( self, - system_variables: Mapping[SystemVariable, Any], + system_variables: Mapping[SystemVariableKey, Any], user_inputs: Mapping[str, Any], environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable] | None = None, @@ -68,7 +68,7 @@ class VariablePool: None """ if len(selector) < 2: - raise ValueError('Invalid selector') + raise ValueError("Invalid selector") if value is None: return @@ -95,13 +95,13 @@ class VariablePool: ValueError: If the selector is invalid. """ if len(selector) < 2: - raise ValueError('Invalid selector') + raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) value = self._variable_dictionary[selector[0]].get(hash_key) return value - @deprecated('This method is deprecated, use `get` instead.') + @deprecated("This method is deprecated, use `get` instead.") def get_any(self, selector: Sequence[str], /) -> Any | None: """ Retrieves the value from the variable pool based on the given selector. @@ -116,7 +116,7 @@ class VariablePool: ValueError: If the selector is invalid. """ if len(selector) < 2: - raise ValueError('Invalid selector') + raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) value = self._variable_dictionary[selector[0]].get(hash_key) return value.to_object() if value else None diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index 4757cf32f8..da65f6b1fb 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -1,25 +1,13 @@ from enum import Enum -class SystemVariable(str, Enum): +class SystemVariableKey(str, Enum): """ System Variables. """ - QUERY = 'query' - FILES = 'files' - CONVERSATION_ID = 'conversation_id' - USER_ID = 'user_id' - DIALOGUE_COUNT = 'dialogue_count' - @classmethod - def value_of(cls, value: str): - """ - Get value of given system variable. - - :param value: system variable value - :return: system variable - """ - for system_variable in cls: - if system_variable.value == value: - return system_variable - raise ValueError(f'invalid system variable value {value}') + QUERY = "query" + FILES = "files" + CONVERSATION_ID = "conversation_id" + USER_ID = "user_id" + DIALOGUE_COUNT = "dialogue_count" diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index c20e0d4506..c3e4949421 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -24,7 +24,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.llm.entities import ( LLMNodeChatModelMessage, @@ -94,7 +94,7 @@ class LLMNode(BaseNode): # fetch prompt messages prompt_messages, stop = self._fetch_prompt_messages( node_data=node_data, - query=variable_pool.get_any(['sys', SystemVariable.QUERY.value]) + query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value]) if node_data.memory else None, query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, inputs=inputs, @@ -335,7 +335,7 @@ class LLMNode(BaseNode): if not node_data.vision.enabled: return [] - files = variable_pool.get_any(['sys', SystemVariable.FILES.value]) + files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value]) if not files: return [] @@ -500,7 +500,7 @@ class LLMNode(BaseNode): return None # get conversation id - conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value]) + conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value]) if conversation_id is None: return None @@ -672,10 +672,10 @@ class LLMNode(BaseNode): variable_mapping['#context#'] = node_data.context.variable_selector if node_data.vision.enabled: - variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value] + variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value] if node_data.memory: - variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value] + variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value] if node_data.prompt_config: enable_jinja = False diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 661b403d32..54e66bd671 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,7 +1,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.start.entities import StartNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -17,16 +17,16 @@ class StartNode(BaseNode): :param variable_pool: variable pool :return: """ - # Get cleaned inputs - cleaned_inputs = dict(variable_pool.user_inputs) + node_inputs = dict(variable_pool.user_inputs) + system_inputs = variable_pool.system_variables - for var in variable_pool.system_variables: - cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var] + for var in system_inputs: + node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var] return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=cleaned_inputs, - outputs=cleaned_inputs + inputs=node_inputs, + outputs=node_inputs ) @classmethod diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 554e3b6074..9b52cd2f6b 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -141,7 +141,7 @@ class ToolNode(BaseNode): return result def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: - variable = variable_pool.get(['sys', SystemVariable.FILES.value]) + variable = variable_pool.get(['sys', SystemVariableKey.FILES.value]) assert isinstance(variable, ArrayAnyVariable) return list(variable.value) if variable else [] diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 4686ce0675..1b27af5af7 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -11,7 +11,7 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import ModelProviderFactory from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.llm.llm_node import LLMNode from extensions.ext_database import db @@ -66,10 +66,10 @@ def test_execute_llm(setup_openai_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather today?', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather today?', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) pool.add(['abc', 'output'], 'sunny') @@ -181,10 +181,10 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather today?', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather today?', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) pool.add(['abc', 'output'], 'sunny') diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index adf5ffe3ca..e32fa59df3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -13,7 +13,7 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from extensions.ext_database import db @@ -119,10 +119,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather in SF', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) result = node.run(pool) @@ -177,10 +177,10 @@ def test_instructions(setup_openai_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather in SF', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) result = node.run(pool) @@ -243,10 +243,10 @@ def test_chat_parameter_extractor(setup_anthropic_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather in SF', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) result = node.run(pool) @@ -307,10 +307,10 @@ def test_completion_parameter_extractor(setup_openai_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather in SF', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) result = node.run(pool) @@ -420,10 +420,10 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather in SF', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) result = node.run(pool) diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py index 7e3e69ffbf..50d991316d 100644 --- a/api/tests/unit_tests/core/app/segments/test_segment.py +++ b/api/tests/unit_tests/core/app/segments/test_segment.py @@ -1,13 +1,13 @@ from core.app.segments import SecretVariable, StringSegment, parser from core.helper import encrypter from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey def test_segment_group_to_text(): variable_pool = VariablePool( system_variables={ - SystemVariable('user_id'): 'fake-user-id', + SystemVariableKey('user_id'): 'fake-user-id', }, user_inputs={}, environment_variables=[ @@ -42,7 +42,7 @@ def test_convert_constant_to_segment_group(): def test_convert_variable_to_segment_group(): variable_pool = VariablePool( system_variables={ - SystemVariable('user_id'): 'fake-user-id', + SystemVariableKey('user_id'): 'fake-user-id', }, user_inputs={}, environment_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 4617b6a42f..44b7c85256 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base_node import UserFrom from extensions.ext_database import db @@ -29,8 +29,8 @@ def test_execute_answer(): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.FILES: [], + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) pool.add(['start', 'weather'], 'sunny') pool.add(['llm', 'text'], 'You are a helpful AI.') diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index d21b7785c4..87ebcb34e6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.if_else.if_else_node import IfElseNode from extensions.ext_database import db @@ -119,8 +119,8 @@ def test_execute_if_else_result_true(): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.FILES: [], + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) pool.add(['start', 'array_contains'], ['ab', 'def']) pool.add(['start', 'array_not_contains'], ['ac', 'def']) @@ -182,8 +182,8 @@ def test_execute_if_else_result_false(): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.FILES: [], + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) pool.add(['start', 'array_contains'], ['1ab', 'def']) pool.add(['start', 'array_not_contains'], ['ab', 'def']) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py index 78b3cf1415..5df8c1b763 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -4,7 +4,7 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.app.segments import ArrayStringVariable, StringVariable from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode @@ -42,7 +42,7 @@ def test_overwrite_string_variable(): ) variable_pool = VariablePool( - system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -93,7 +93,7 @@ def test_append_variable_to_array(): ) variable_pool = VariablePool( - system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -137,7 +137,7 @@ def test_clear_array(): ) variable_pool = VariablePool( - system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable],