ragflow/api/utils/validation_utils.py
liu an fed1221302
Refa: HTTP API list datasets / test cases / docs (#7720)
### What problem does this PR solve?

This PR introduces Pydantic-based validation for the list datasets HTTP
API, improving code clarity and robustness. Key changes include:

Pydantic Validation
Error Handling
Test Updates
Documentation Updates

### Type of change

- [x] Documentation Update
- [x] Refactoring
2025-05-20 09:58:26 +08:00

654 lines
23 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import Counter
from enum import auto
from typing import Annotated, Any
from uuid import UUID
from flask import Request
from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator
from pydantic_core import PydanticCustomError
from strenum import StrEnum
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
from api.constants import DATASET_NAME_LIMIT
def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
"""
Validates and parses JSON requests through a multi-stage validation pipeline.
Implements a four-stage validation process:
1. Content-Type verification (must be application/json)
2. JSON syntax validation
3. Payload structure type checking
4. Pydantic model validation with error formatting
Args:
request (Request): Flask request object containing HTTP payload
validator (type[BaseModel]): Pydantic model class for data validation
extras (dict[str, Any] | None): Additional fields to merge into payload
before validation. These fields will be removed from the final output
exclude_unset (bool): Whether to exclude fields that have not been explicitly set
Returns:
tuple[Dict[str, Any] | None, str | None]:
- First element:
- Validated dictionary on success
- None on validation failure
- Second element:
- None on success
- Diagnostic error message on failure
Raises:
UnsupportedMediaType: When Content-Type header is not application/json
BadRequest: For structural JSON syntax errors
ValidationError: When payload violates Pydantic schema rules
Examples:
>>> validate_and_parse_json_request(valid_request, DatasetSchema)
({"name": "Dataset1", "format": "csv"}, None)
>>> validate_and_parse_json_request(xml_request, DatasetSchema)
(None, "Unsupported content type: Expected application/json, got text/xml")
>>> validate_and_parse_json_request(bad_json_request, DatasetSchema)
(None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding")
Notes:
1. Validation Priority:
- Content-Type verification precedes JSON parsing
- Structural validation occurs before schema validation
2. Extra fields added via `extras` parameter are automatically removed
from the final output after validation
"""
try:
payload = request.get_json() or {}
except UnsupportedMediaType:
return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
except BadRequest:
return None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding"
if not isinstance(payload, dict):
return None, f"Invalid request payload: expected object, got {type(payload).__name__}"
try:
if extras is not None:
payload.update(extras)
validated_request = validator(**payload)
except ValidationError as e:
return None, format_validation_error_message(e)
parsed_payload = validated_request.model_dump(by_alias=True, exclude_unset=exclude_unset)
if extras is not None:
for key in list(parsed_payload.keys()):
if key in extras:
del parsed_payload[key]
return parsed_payload, None
def validate_and_parse_request_args(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None) -> tuple[dict[str, Any] | None, str | None]:
"""
Validates and parses request arguments against a Pydantic model.
This function performs a complete request validation workflow:
1. Extracts query parameters from the request
2. Merges with optional extra values (if provided)
3. Validates against the specified Pydantic model
4. Cleans the output by removing extra values
5. Returns either parsed data or an error message
Args:
request (Request): Web framework request object containing query parameters
validator (type[BaseModel]): Pydantic model class for validation
extras (dict[str, Any] | None): Optional additional values to include in validation
but exclude from final output. Defaults to None.
Returns:
tuple[dict[str, Any] | None, str | None]:
- First element: Validated/parsed arguments as dict if successful, None otherwise
- Second element: Formatted error message if validation failed, None otherwise
Behavior:
- Query parameters are merged with extras before validation
- Extras are automatically removed from the final output
- All validation errors are formatted into a human-readable string
Raises:
TypeError: If validator is not a Pydantic BaseModel subclass
Examples:
Successful validation:
>>> validate_and_parse_request_args(request, MyValidator)
({'param1': 'value'}, None)
Failed validation:
>>> validate_and_parse_request_args(request, MyValidator)
(None, "param1: Field required")
With extras:
>>> validate_and_parse_request_args(request, MyValidator, extras={'internal_id': 123})
({'param1': 'value'}, None) # internal_id removed from output
Notes:
- Uses request.args.to_dict() for Flask-compatible parameter extraction
- Maintains immutability of original request arguments
- Preserves type conversion from Pydantic validation
"""
args = request.args.to_dict(flat=True)
try:
if extras is not None:
args.update(extras)
validated_args = validator(**args)
except ValidationError as e:
return None, format_validation_error_message(e)
parsed_args = validated_args.model_dump()
if extras is not None:
for key in list(parsed_args.keys()):
if key in extras:
del parsed_args[key]
return parsed_args, None
def format_validation_error_message(e: ValidationError) -> str:
"""
Formats validation errors into a standardized string format.
Processes pydantic ValidationError objects to create human-readable error messages
containing field locations, error descriptions, and input values.
Args:
e (ValidationError): The validation error instance containing error details
Returns:
str: Formatted error messages joined by newlines. Each line contains:
- Field path (dot-separated)
- Error message
- Truncated input value (max 128 chars)
Example:
>>> try:
... UserModel(name=123, email="invalid")
... except ValidationError as e:
... print(format_validation_error_message(e))
Field: <name> - Message: <Input should be a valid string> - Value: <123>
Field: <email> - Message: <value is not a valid email address> - Value: <invalid>
"""
error_messages = []
for error in e.errors():
field = ".".join(map(str, error["loc"]))
msg = error["msg"]
input_val = error["input"]
input_str = str(input_val)
if len(input_str) > 128:
input_str = input_str[:125] + "..."
error_msg = f"Field: <{field}> - Message: <{msg}> - Value: <{input_str}>"
error_messages.append(error_msg)
return "\n".join(error_messages)
def normalize_str(v: Any) -> Any:
"""
Normalizes string values to a standard format while preserving non-string inputs.
Performs the following transformations when input is a string:
1. Trims leading/trailing whitespace (str.strip())
2. Converts to lowercase (str.lower())
Non-string inputs are returned unchanged, making this function safe for mixed-type
processing pipelines.
Args:
v (Any): Input value to normalize. Accepts any Python object.
Returns:
Any: Normalized string if input was string-type, original value otherwise.
Behavior Examples:
String Input: " Admin ""admin"
Empty String: " """ (empty string)
Non-String:
- 123 → 123
- None → None
- ["User"] → ["User"]
Typical Use Cases:
- Standardizing user input
- Preparing data for case-insensitive comparison
- Cleaning API parameters
- Normalizing configuration values
Edge Cases:
- Unicode whitespace is handled by str.strip()
- Locale-independent lowercasing (str.lower())
- Preserves falsy values (0, False, etc.)
Example:
>>> normalize_str(" ReadOnly ")
'readonly'
>>> normalize_str(42)
42
"""
if isinstance(v, str):
stripped = v.strip()
normalized = stripped.lower()
return normalized
return v
def validate_uuid1_hex(v: Any) -> str:
"""
Validates and converts input to a UUID version 1 hexadecimal string.
This function performs strict validation and normalization:
1. Accepts either UUID objects or UUID-formatted strings
2. Verifies the UUID is version 1 (time-based)
3. Returns the 32-character hexadecimal representation
Args:
v (Any): Input value to validate. Can be:
- UUID object (must be version 1)
- String in UUID format (e.g. "550e8400-e29b-41d4-a716-446655440000")
Returns:
str: 32-character lowercase hexadecimal string without hyphens
Example: "550e8400e29b41d4a716446655440000"
Raises:
PydanticCustomError: With code "invalid_UUID1_format" when:
- Input is not a UUID object or valid UUID string
- UUID version is not 1
- String doesn't match UUID format
Examples:
Valid cases:
>>> validate_uuid1_hex("550e8400-e29b-41d4-a716-446655440000")
'550e8400e29b41d4a716446655440000'
>>> validate_uuid1_hex(UUID('550e8400-e29b-41d4-a716-446655440000'))
'550e8400e29b41d4a716446655440000'
Invalid cases:
>>> validate_uuid1_hex("not-a-uuid") # raises PydanticCustomError
>>> validate_uuid1_hex(12345) # raises PydanticCustomError
>>> validate_uuid1_hex(UUID(int=0)) # v4, raises PydanticCustomError
Notes:
- Uses Python's built-in UUID parser for format validation
- Version check prevents accidental use of other UUID versions
- Hyphens in input strings are automatically removed in output
"""
try:
uuid_obj = UUID(v) if isinstance(v, str) else v
if uuid_obj.version != 1:
raise PydanticCustomError("invalid_UUID1_format", "Must be a UUID1 format")
return uuid_obj.hex
except (AttributeError, ValueError, TypeError):
raise PydanticCustomError("invalid_UUID1_format", "Invalid UUID1 format")
class PermissionEnum(StrEnum):
me = auto()
team = auto()
class ChunkMethodnEnum(StrEnum):
naive = auto()
book = auto()
email = auto()
laws = auto()
manual = auto()
one = auto()
paper = auto()
picture = auto()
presentation = auto()
qa = auto()
table = auto()
tag = auto()
class GraphragMethodEnum(StrEnum):
light = auto()
general = auto()
class Base(BaseModel):
class Config:
extra = "forbid"
class RaptorConfig(Base):
use_raptor: bool = Field(default=False)
prompt: Annotated[
str,
StringConstraints(strip_whitespace=True, min_length=1),
Field(
default="Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize."
),
]
max_token: int = Field(default=256, ge=1, le=2048)
threshold: float = Field(default=0.1, ge=0.0, le=1.0)
max_cluster: int = Field(default=64, ge=1, le=1024)
random_seed: int = Field(default=0, ge=0)
class GraphragConfig(Base):
use_graphrag: bool = Field(default=False)
entity_types: list[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])
method: GraphragMethodEnum = Field(default=GraphragMethodEnum.light)
community: bool = Field(default=False)
resolution: bool = Field(default=False)
class ParserConfig(Base):
auto_keywords: int = Field(default=0, ge=0, le=32)
auto_questions: int = Field(default=0, ge=0, le=10)
chunk_token_num: int = Field(default=128, ge=1, le=2048)
delimiter: str = Field(default=r"\n", min_length=1)
graphrag: GraphragConfig | None = None
html4excel: bool = False
layout_recognize: str = "DeepDOC"
raptor: RaptorConfig | None = None
tag_kb_ids: list[str] = Field(default_factory=list)
topn_tags: int = Field(default=1, ge=1, le=10)
filename_embd_weight: float | None = Field(default=None, ge=0.0, le=1.0)
task_page_size: int | None = Field(default=None, ge=1)
pages: list[list[int]] | None = None
class CreateDatasetReq(Base):
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
avatar: str | None = Field(default=None, max_length=65535)
description: str | None = Field(default=None, max_length=65535)
embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
pagerank: int = Field(default=0, ge=0, le=100)
parser_config: ParserConfig | None = Field(default=None)
@field_validator("avatar")
@classmethod
def validate_avatar_base64(cls, v: str | None) -> str | None:
"""
Validates Base64-encoded avatar string format and MIME type compliance.
Implements a three-stage validation workflow:
1. MIME prefix existence check
2. MIME type format validation
3. Supported type verification
Args:
v (str): Raw avatar field value
Returns:
str: Validated Base64 string
Raises:
PydanticCustomError: For structural errors in these cases:
- Missing MIME prefix header
- Invalid MIME prefix format
- Unsupported image MIME type
Example:
```python
# Valid case
CreateDatasetReq(avatar="...")
# Invalid cases
CreateDatasetReq(avatar="image/jpeg;base64,...") # Missing 'data:' prefix
CreateDatasetReq(avatar="data:video/mp4;base64,...") # Unsupported MIME type
```
"""
if v is None:
return v
if "," in v:
prefix, _ = v.split(",", 1)
if not prefix.startswith("data:"):
raise PydanticCustomError("format_invalid", "Invalid MIME prefix format. Must start with 'data:'")
mime_type = prefix[5:].split(";")[0]
supported_mime_types = ["image/jpeg", "image/png"]
if mime_type not in supported_mime_types:
raise PydanticCustomError("format_invalid", "Unsupported MIME type. Allowed: {supported_mime_types}", {"supported_mime_types": supported_mime_types})
return v
else:
raise PydanticCustomError("format_invalid", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>")
@field_validator("embedding_model", mode="after")
@classmethod
def validate_embedding_model(cls, v: str) -> str:
"""
Validates embedding model identifier format compliance.
Validation pipeline:
1. Structural format verification
2. Component non-empty check
3. Value normalization
Args:
v (str): Raw model identifier
Returns:
str: Validated <model_name>@<provider> format
Raises:
PydanticCustomError: For these violations:
- Missing @ separator
- Empty model_name/provider
- Invalid component structure
Examples:
Valid: "text-embedding-3-large@openai"
Invalid: "invalid_model" (no @)
Invalid: "@openai" (empty model_name)
Invalid: "text-embedding-3-large@" (empty provider)
"""
if "@" not in v:
raise PydanticCustomError("format_invalid", "Embedding model identifier must follow <model_name>@<provider> format")
components = v.split("@", 1)
if len(components) != 2 or not all(components):
raise PydanticCustomError("format_invalid", "Both model_name and provider must be non-empty strings")
model_name, provider = components
if not model_name.strip() or not provider.strip():
raise PydanticCustomError("format_invalid", "Model name and provider cannot be whitespace-only strings")
return v
@field_validator("permission", mode="before")
@classmethod
def normalize_permission(cls, v: Any) -> Any:
return normalize_str(v)
@field_validator("parser_config", mode="before")
@classmethod
def normalize_empty_parser_config(cls, v: Any) -> Any:
"""
Normalizes empty parser configuration by converting empty dictionaries to None.
This validator ensures consistent handling of empty parser configurations across
the application by converting empty dicts to None values.
Args:
v (Any): Raw input value for the parser config field
Returns:
Any: Returns None if input is an empty dict, otherwise returns the original value
Example:
>>> normalize_empty_parser_config({})
None
>>> normalize_empty_parser_config({"key": "value"})
{"key": "value"}
"""
if v == {}:
return None
return v
@field_validator("parser_config", mode="after")
@classmethod
def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserConfig | None:
"""
Validates serialized JSON length constraints for parser configuration.
Implements a two-stage validation workflow:
1. Null check - bypass validation for empty configurations
2. Model serialization - convert Pydantic model to JSON string
3. Size verification - enforce maximum allowed payload size
Args:
v (ParserConfig | None): Raw parser configuration object
Returns:
ParserConfig | None: Validated configuration object
Raises:
PydanticCustomError: When serialized JSON exceeds 65,535 characters
"""
if v is None:
return None
if (json_str := v.model_dump_json()) and len(json_str) > 65535:
raise PydanticCustomError("string_too_long", "Parser config exceeds size limit (max 65,535 characters). Current size: {actual}", {"actual": len(json_str)})
return v
class UpdateDatasetReq(CreateDatasetReq):
dataset_id: str = Field(...)
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
@field_validator("dataset_id", mode="before")
@classmethod
def validate_dataset_id(cls, v: Any) -> str:
return validate_uuid1_hex(v)
class DeleteReq(Base):
ids: list[str] | None = Field(...)
@field_validator("ids", mode="after")
@classmethod
def validate_ids(cls, v_list: list[str] | None) -> list[str] | None:
"""
Validates and normalizes a list of UUID strings with None handling.
This post-processing validator performs:
1. None input handling (pass-through)
2. UUID version 1 validation for each list item
3. Duplicate value detection
4. Returns normalized UUID hex strings or None
Args:
v_list (list[str] | None): Input list that has passed initial validation.
Either a list of UUID strings or None.
Returns:
list[str] | None:
- None if input was None
- List of normalized UUID hex strings otherwise:
* 32-character lowercase
* Valid UUID version 1
* Unique within list
Raises:
PydanticCustomError: With structured error details when:
- "invalid_UUID1_format": Any string fails UUIDv1 validation
- "duplicate_uuids": If duplicate IDs are detected
Validation Rules:
- None input returns None
- Empty list returns empty list
- All non-None items must be valid UUIDv1
- No duplicates permitted
- Original order preserved
Examples:
Valid cases:
>>> validate_ids(None)
None
>>> validate_ids([])
[]
>>> validate_ids(["550e8400-e29b-41d4-a716-446655440000"])
["550e8400e29b41d4a716446655440000"]
Invalid cases:
>>> validate_ids(["invalid"])
# raises PydanticCustomError(invalid_UUID1_format)
>>> validate_ids(["550e...", "550e..."])
# raises PydanticCustomError(duplicate_uuids)
Security Notes:
- Validates UUID version to prevent version spoofing
- Duplicate check prevents data injection
- None handling maintains pipeline integrity
"""
if v_list is None:
return None
ids_list = []
for v in v_list:
try:
ids_list.append(validate_uuid1_hex(v))
except PydanticCustomError as e:
raise e
duplicates = [item for item, count in Counter(ids_list).items() if count > 1]
if duplicates:
duplicates_str = ", ".join(duplicates)
raise PydanticCustomError("duplicate_uuids", "Duplicate ids: '{duplicate_ids}'", {"duplicate_ids": duplicates_str})
return ids_list
class DeleteDatasetReq(DeleteReq): ...
class OrderByEnum(StrEnum):
create_time = auto()
update_time = auto()
class BaseListReq(Base):
id: str | None = None
name: str | None = None
page: int = Field(default=1, ge=1)
page_size: int = Field(default=30, ge=1)
orderby: OrderByEnum = Field(default=OrderByEnum.create_time)
desc: bool = Field(default=True)
@field_validator("id", mode="before")
@classmethod
def validate_id(cls, v: Any) -> str:
return validate_uuid1_hex(v)
@field_validator("orderby", mode="before")
@classmethod
def normalize_orderby(cls, v: Any) -> Any:
return normalize_str(v)
class ListDatasetReq(BaseListReq): ...