mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 15:36:00 +08:00
fix: type num of variable converted to str (#3758)
This commit is contained in:
parent
f257f2c396
commit
2ea8c73cd8
@ -23,20 +23,28 @@ class BaseAppGenerator:
|
|||||||
value = user_inputs[variable]
|
value = user_inputs[variable]
|
||||||
|
|
||||||
if value:
|
if value:
|
||||||
if not isinstance(value, str):
|
if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
|
||||||
raise ValueError(f"{variable} in input form must be a string")
|
raise ValueError(f"{variable} in input form must be a string")
|
||||||
|
elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
|
||||||
|
if '.' in value:
|
||||||
|
value = float(value)
|
||||||
|
else:
|
||||||
|
value = int(value)
|
||||||
|
|
||||||
if variable_config.type == VariableEntity.Type.SELECT:
|
if variable_config.type == VariableEntity.Type.SELECT:
|
||||||
options = variable_config.options if variable_config.options is not None else []
|
options = variable_config.options if variable_config.options is not None else []
|
||||||
if value not in options:
|
if value not in options:
|
||||||
raise ValueError(f"{variable} in input form must be one of the following: {options}")
|
raise ValueError(f"{variable} in input form must be one of the following: {options}")
|
||||||
else:
|
elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
|
||||||
if variable_config.max_length is not None:
|
if variable_config.max_length is not None:
|
||||||
max_length = variable_config.max_length
|
max_length = variable_config.max_length
|
||||||
if len(value) > max_length:
|
if len(value) > max_length:
|
||||||
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
|
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
|
||||||
|
|
||||||
filtered_inputs[variable] = value.replace('\x00', '') if value else None
|
if value and isinstance(value, str):
|
||||||
|
filtered_inputs[variable] = value.replace('\x00', '')
|
||||||
|
else:
|
||||||
|
filtered_inputs[variable] = value if value else None
|
||||||
|
|
||||||
return filtered_inputs
|
return filtered_inputs
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ class AppGenerateEntity(BaseModel):
|
|||||||
# app config
|
# app config
|
||||||
app_config: AppConfig
|
app_config: AppConfig
|
||||||
|
|
||||||
inputs: dict[str, str]
|
inputs: dict[str, Any]
|
||||||
files: list[FileVar] = []
|
files: list[FileVar] = []
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
||||||
|
@ -32,6 +32,8 @@ class AdvancedPromptTransform(PromptTransform):
|
|||||||
memory_config: Optional[MemoryConfig],
|
memory_config: Optional[MemoryConfig],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
|
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
|
||||||
|
inputs = {key: str(value) for key, value in inputs.items()}
|
||||||
|
|
||||||
prompt_messages = []
|
prompt_messages = []
|
||||||
|
|
||||||
model_mode = ModelMode.value_of(model_config.mode)
|
model_mode = ModelMode.value_of(model_config.mode)
|
||||||
|
@ -55,6 +55,8 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity) -> \
|
model_config: ModelConfigWithCredentialsEntity) -> \
|
||||||
tuple[list[PromptMessage], Optional[list[str]]]:
|
tuple[list[PromptMessage], Optional[list[str]]]:
|
||||||
|
inputs = {key: str(value) for key, value in inputs.items()}
|
||||||
|
|
||||||
model_mode = ModelMode.value_of(model_config.mode)
|
model_mode = ModelMode.value_of(model_config.mode)
|
||||||
if model_mode == ModelMode.CHAT:
|
if model_mode == ModelMode.CHAT:
|
||||||
prompt_messages, stops = self._get_chat_model_prompt_messages(
|
prompt_messages, stops = self._get_chat_model_prompt_messages(
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
from core.app.app_config.entities import VariableEntity
|
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
@ -19,12 +17,8 @@ class StartNode(BaseNode):
|
|||||||
:param variable_pool: variable pool
|
:param variable_pool: variable pool
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
node_data = self.node_data
|
|
||||||
node_data = cast(self._node_data_cls, node_data)
|
|
||||||
variables = node_data.variables
|
|
||||||
|
|
||||||
# Get cleaned inputs
|
# Get cleaned inputs
|
||||||
cleaned_inputs = self._get_cleaned_inputs(variables, variable_pool.user_inputs)
|
cleaned_inputs = variable_pool.user_inputs
|
||||||
|
|
||||||
for var in variable_pool.system_variables:
|
for var in variable_pool.system_variables:
|
||||||
if var == SystemVariable.CONVERSATION:
|
if var == SystemVariable.CONVERSATION:
|
||||||
@ -38,42 +32,6 @@ class StartNode(BaseNode):
|
|||||||
outputs=cleaned_inputs
|
outputs=cleaned_inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict):
|
|
||||||
if user_inputs is None:
|
|
||||||
user_inputs = {}
|
|
||||||
|
|
||||||
filtered_inputs = {}
|
|
||||||
|
|
||||||
for variable_config in variables:
|
|
||||||
variable = variable_config.variable
|
|
||||||
|
|
||||||
if variable not in user_inputs or not user_inputs[variable]:
|
|
||||||
if variable_config.required:
|
|
||||||
raise ValueError(f"Input form variable {variable} is required")
|
|
||||||
else:
|
|
||||||
filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
|
|
||||||
continue
|
|
||||||
|
|
||||||
value = user_inputs[variable]
|
|
||||||
|
|
||||||
if value:
|
|
||||||
if not isinstance(value, str):
|
|
||||||
raise ValueError(f"{variable} in input form must be a string")
|
|
||||||
|
|
||||||
if variable_config.type == VariableEntity.Type.SELECT:
|
|
||||||
options = variable_config.options if variable_config.options is not None else []
|
|
||||||
if value not in options:
|
|
||||||
raise ValueError(f"{variable} in input form must be one of the following: {options}")
|
|
||||||
else:
|
|
||||||
if variable_config.max_length is not None:
|
|
||||||
max_length = variable_config.max_length
|
|
||||||
if len(value) > max_length:
|
|
||||||
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
|
|
||||||
|
|
||||||
filtered_inputs[variable] = value.replace('\x00', '') if value else None
|
|
||||||
|
|
||||||
return filtered_inputs
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user