mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-06-04 11:14:10 +08:00
refactor(api/core/app/apps/base_app_generator.py): improve input validation and sanitization in BaseAppGenerator (#5866)
This commit is contained in:
parent
04c0a9ad45
commit
66a62e6c13
@ -114,6 +114,10 @@ class VariableEntity(BaseModel):
|
|||||||
default: Optional[str] = None
|
default: Optional[str] = None
|
||||||
hint: Optional[str] = None
|
hint: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self.variable
|
||||||
|
|
||||||
|
|
||||||
class ExternalDataVariableEntity(BaseModel):
|
class ExternalDataVariableEntity(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -1,52 +1,56 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import AppConfig, VariableEntity
|
from core.app.app_config.entities import AppConfig, VariableEntity
|
||||||
|
|
||||||
|
|
||||||
class BaseAppGenerator:
|
class BaseAppGenerator:
|
||||||
def _get_cleaned_inputs(self, user_inputs: dict, app_config: AppConfig):
|
def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
|
||||||
if user_inputs is None:
|
user_inputs = user_inputs or {}
|
||||||
user_inputs = {}
|
|
||||||
|
|
||||||
filtered_inputs = {}
|
|
||||||
|
|
||||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||||
variables = app_config.variables
|
variables = app_config.variables
|
||||||
for variable_config in variables:
|
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||||
variable = variable_config.variable
|
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
|
||||||
|
|
||||||
if (variable not in user_inputs
|
|
||||||
or user_inputs[variable] is None
|
|
||||||
or (isinstance(user_inputs[variable], str) and user_inputs[variable] == '')):
|
|
||||||
if variable_config.required:
|
|
||||||
raise ValueError(f"{variable} is required in input form")
|
|
||||||
else:
|
|
||||||
filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
|
|
||||||
continue
|
|
||||||
|
|
||||||
value = user_inputs[variable]
|
|
||||||
|
|
||||||
if value is not None:
|
|
||||||
if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
|
|
||||||
raise ValueError(f"{variable} in input form must be a string")
|
|
||||||
elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
|
|
||||||
if '.' in value:
|
|
||||||
value = float(value)
|
|
||||||
else:
|
|
||||||
value = int(value)
|
|
||||||
|
|
||||||
if variable_config.type == VariableEntity.Type.SELECT:
|
|
||||||
options = variable_config.options if variable_config.options is not None else []
|
|
||||||
if value not in options:
|
|
||||||
raise ValueError(f"{variable} in input form must be one of the following: {options}")
|
|
||||||
elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
|
|
||||||
if variable_config.max_length is not None:
|
|
||||||
max_length = variable_config.max_length
|
|
||||||
if len(value) > max_length:
|
|
||||||
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
|
|
||||||
|
|
||||||
if value and isinstance(value, str):
|
|
||||||
filtered_inputs[variable] = value.replace('\x00', '')
|
|
||||||
else:
|
|
||||||
filtered_inputs[variable] = value if value is not None else None
|
|
||||||
|
|
||||||
return filtered_inputs
|
return filtered_inputs
|
||||||
|
|
||||||
|
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
|
||||||
|
user_input_value = inputs.get(var.name)
|
||||||
|
if var.required and not user_input_value:
|
||||||
|
raise ValueError(f'{var.name} is required in input form')
|
||||||
|
if not var.required and not user_input_value:
|
||||||
|
# TODO: should we return None here if the default value is None?
|
||||||
|
return var.default or ''
|
||||||
|
if (
|
||||||
|
var.type
|
||||||
|
in (
|
||||||
|
VariableEntity.Type.TEXT_INPUT,
|
||||||
|
VariableEntity.Type.SELECT,
|
||||||
|
VariableEntity.Type.PARAGRAPH,
|
||||||
|
)
|
||||||
|
and user_input_value
|
||||||
|
and not isinstance(user_input_value, str)
|
||||||
|
):
|
||||||
|
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
|
||||||
|
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
|
||||||
|
# may raise ValueError if user_input_value is not a valid number
|
||||||
|
try:
|
||||||
|
if '.' in user_input_value:
|
||||||
|
return float(user_input_value)
|
||||||
|
else:
|
||||||
|
return int(user_input_value)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"{var.name} in input form must be a valid number")
|
||||||
|
if var.type == VariableEntity.Type.SELECT:
|
||||||
|
options = var.options or []
|
||||||
|
if user_input_value not in options:
|
||||||
|
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
|
||||||
|
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
|
||||||
|
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
|
||||||
|
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
|
||||||
|
|
||||||
|
return user_input_value
|
||||||
|
|
||||||
|
def _sanitize_value(self, value: Any) -> Any:
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value.replace('\x00', '')
|
||||||
|
return value
|
||||||
|
Loading…
x
Reference in New Issue
Block a user