From 2ea8c73cd8a21d5cd13045c335083e366cfcc976 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 24 Apr 2024 15:07:56 +0800 Subject: [PATCH] fix: type num of variable converted to str (#3758) --- api/core/app/apps/base_app_generator.py | 14 +++++-- api/core/app/entities/app_invoke_entities.py | 2 +- api/core/prompt/advanced_prompt_transform.py | 2 + api/core/prompt/simple_prompt_transform.py | 2 + api/core/workflow/nodes/start/start_node.py | 44 +------------------- 5 files changed, 17 insertions(+), 47 deletions(-) diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 750c6dae10..9d88c834e6 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -23,20 +23,28 @@ class BaseAppGenerator: value = user_inputs[variable] if value: - if not isinstance(value, str): + if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str): raise ValueError(f"{variable} in input form must be a string") + elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str): + if '.' in value: + value = float(value) + else: + value = int(value) if variable_config.type == VariableEntity.Type.SELECT: options = variable_config.options if variable_config.options is not None else [] if value not in options: raise ValueError(f"{variable} in input form must be one of the following: {options}") - else: + elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]: if variable_config.max_length is not None: max_length = variable_config.max_length if len(value) > max_length: raise ValueError(f'{variable} in input form must be less than {max_length} characters') - filtered_inputs[variable] = value.replace('\x00', '') if value else None + if value and isinstance(value, str): + filtered_inputs[variable] = value.replace('\x00', '') + else: + filtered_inputs[variable] = value if value else None return filtered_inputs diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index c05a8a77d0..09c62c802c 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -72,7 +72,7 @@ class AppGenerateEntity(BaseModel): # app config app_config: AppConfig - inputs: dict[str, str] + inputs: dict[str, Any] files: list[FileVar] = [] user_id: str diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 674ba29b6e..d3480d2c47 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -32,6 +32,8 @@ class AdvancedPromptTransform(PromptTransform): memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + inputs = {key: str(value) for key, value in inputs.items()} + prompt_messages = [] model_mode = ModelMode.value_of(model_config.mode) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index cd74438337..9b0c96b8bf 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -55,6 +55,8 @@ class SimplePromptTransform(PromptTransform): memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> \ tuple[list[PromptMessage], Optional[list[str]]]: + inputs = {key: str(value) for key, value in inputs.items()} + model_mode = ModelMode.value_of(model_config.mode) if model_mode == ModelMode.CHAT: prompt_messages, stops = self._get_chat_model_prompt_messages( diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index b897767e9d..e32e850a23 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,6 +1,4 @@ -from typing import cast -from core.app.app_config.entities import VariableEntity from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool @@ -19,12 +17,8 @@ class StartNode(BaseNode): :param variable_pool: variable pool :return: """ - node_data = self.node_data - node_data = cast(self._node_data_cls, node_data) - variables = node_data.variables - # Get cleaned inputs - cleaned_inputs = self._get_cleaned_inputs(variables, variable_pool.user_inputs) + cleaned_inputs = variable_pool.user_inputs for var in variable_pool.system_variables: if var == SystemVariable.CONVERSATION: @@ -38,42 +32,6 @@ class StartNode(BaseNode): outputs=cleaned_inputs ) - def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict): - if user_inputs is None: - user_inputs = {} - - filtered_inputs = {} - - for variable_config in variables: - variable = variable_config.variable - - if variable not in user_inputs or not user_inputs[variable]: - if variable_config.required: - raise ValueError(f"Input form variable {variable} is required") - else: - filtered_inputs[variable] = variable_config.default if variable_config.default is not None else "" - continue - - value = user_inputs[variable] - - if value: - if not isinstance(value, str): - raise ValueError(f"{variable} in input form must be a string") - - if variable_config.type == VariableEntity.Type.SELECT: - options = variable_config.options if variable_config.options is not None else [] - if value not in options: - raise ValueError(f"{variable} in input form must be one of the following: {options}") - else: - if variable_config.max_length is not None: - max_length = variable_config.max_length - if len(value) > max_length: - raise ValueError(f'{variable} in input form must be less than {max_length} characters') - - filtered_inputs[variable] = value.replace('\x00', '') if value else None - - return filtered_inputs - @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """