fix: type num of variable converted to str (#3758)

This commit is contained in:
takatost 2024-04-24 15:07:56 +08:00 committed by GitHub
parent f257f2c396
commit 2ea8c73cd8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 17 additions and 47 deletions

View File

@ -23,20 +23,28 @@ class BaseAppGenerator:
value = user_inputs[variable] value = user_inputs[variable]
if value: if value:
if not isinstance(value, str): if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
raise ValueError(f"{variable} in input form must be a string") raise ValueError(f"{variable} in input form must be a string")
elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
if '.' in value:
value = float(value)
else:
value = int(value)
if variable_config.type == VariableEntity.Type.SELECT: if variable_config.type == VariableEntity.Type.SELECT:
options = variable_config.options if variable_config.options is not None else [] options = variable_config.options if variable_config.options is not None else []
if value not in options: if value not in options:
raise ValueError(f"{variable} in input form must be one of the following: {options}") raise ValueError(f"{variable} in input form must be one of the following: {options}")
else: elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
if variable_config.max_length is not None: if variable_config.max_length is not None:
max_length = variable_config.max_length max_length = variable_config.max_length
if len(value) > max_length: if len(value) > max_length:
raise ValueError(f'{variable} in input form must be less than {max_length} characters') raise ValueError(f'{variable} in input form must be less than {max_length} characters')
filtered_inputs[variable] = value.replace('\x00', '') if value else None if value and isinstance(value, str):
filtered_inputs[variable] = value.replace('\x00', '')
else:
filtered_inputs[variable] = value if value else None
return filtered_inputs return filtered_inputs

View File

@ -72,7 +72,7 @@ class AppGenerateEntity(BaseModel):
# app config # app config
app_config: AppConfig app_config: AppConfig
inputs: dict[str, str] inputs: dict[str, Any]
files: list[FileVar] = [] files: list[FileVar] = []
user_id: str user_id: str

View File

@ -32,6 +32,8 @@ class AdvancedPromptTransform(PromptTransform):
memory_config: Optional[MemoryConfig], memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
inputs = {key: str(value) for key, value in inputs.items()}
prompt_messages = [] prompt_messages = []
model_mode = ModelMode.value_of(model_config.mode) model_mode = ModelMode.value_of(model_config.mode)

View File

@ -55,6 +55,8 @@ class SimplePromptTransform(PromptTransform):
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) -> \ model_config: ModelConfigWithCredentialsEntity) -> \
tuple[list[PromptMessage], Optional[list[str]]]: tuple[list[PromptMessage], Optional[list[str]]]:
inputs = {key: str(value) for key, value in inputs.items()}
model_mode = ModelMode.value_of(model_config.mode) model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.CHAT: if model_mode == ModelMode.CHAT:
prompt_messages, stops = self._get_chat_model_prompt_messages( prompt_messages, stops = self._get_chat_model_prompt_messages(

View File

@ -1,6 +1,4 @@
from typing import cast
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
@ -19,12 +17,8 @@ class StartNode(BaseNode):
:param variable_pool: variable pool :param variable_pool: variable pool
:return: :return:
""" """
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
variables = node_data.variables
# Get cleaned inputs # Get cleaned inputs
cleaned_inputs = self._get_cleaned_inputs(variables, variable_pool.user_inputs) cleaned_inputs = variable_pool.user_inputs
for var in variable_pool.system_variables: for var in variable_pool.system_variables:
if var == SystemVariable.CONVERSATION: if var == SystemVariable.CONVERSATION:
@ -38,42 +32,6 @@ class StartNode(BaseNode):
outputs=cleaned_inputs outputs=cleaned_inputs
) )
def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict):
if user_inputs is None:
user_inputs = {}
filtered_inputs = {}
for variable_config in variables:
variable = variable_config.variable
if variable not in user_inputs or not user_inputs[variable]:
if variable_config.required:
raise ValueError(f"Input form variable {variable} is required")
else:
filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
continue
value = user_inputs[variable]
if value:
if not isinstance(value, str):
raise ValueError(f"{variable} in input form must be a string")
if variable_config.type == VariableEntity.Type.SELECT:
options = variable_config.options if variable_config.options is not None else []
if value not in options:
raise ValueError(f"{variable} in input form must be one of the following: {options}")
else:
if variable_config.max_length is not None:
max_length = variable_config.max_length
if len(value) > max_length:
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
filtered_inputs[variable] = value.replace('\x00', '') if value else None
return filtered_inputs
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
""" """