mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-13 10:49:05 +08:00
Refa: HTTP API list datasets / test cases / docs (#7720)
### What problem does this PR solve?

This PR introduces Pydantic-based validation for the list datasets HTTP API, improving code clarity and robustness. Key changes include:

- Pydantic Validation
- Error Handling
- Test Updates
- Documentation Updates

### Type of change

- [x] Documentation Update
- [x] Refactoring
This commit is contained in:
parent
6ed81d6774
commit
fed1221302
@ -32,12 +32,22 @@ from api.utils.api_utils import (
|
||||
deep_merge,
|
||||
get_error_argument_result,
|
||||
get_error_data_result,
|
||||
get_error_operating_result,
|
||||
get_error_permission_result,
|
||||
get_parser_config,
|
||||
get_result,
|
||||
remap_dictionary_keys,
|
||||
token_required,
|
||||
verify_embedding_availability,
|
||||
)
|
||||
from api.utils.validation_utils import CreateDatasetReq, DeleteDatasetReq, UpdateDatasetReq, validate_and_parse_json_request
|
||||
from api.utils.validation_utils import (
|
||||
CreateDatasetReq,
|
||||
DeleteDatasetReq,
|
||||
ListDatasetReq,
|
||||
UpdateDatasetReq,
|
||||
validate_and_parse_json_request,
|
||||
validate_and_parse_request_args,
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/datasets", methods=["POST"]) # noqa: F821
|
||||
@ -113,7 +123,7 @@ def create(tenant_id):
|
||||
|
||||
try:
|
||||
if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
|
||||
return get_error_operating_result(message=f"Dataset name '{req['name']}' already exists")
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
@ -126,7 +136,7 @@ def create(tenant_id):
|
||||
try:
|
||||
ok, t = TenantService.get_by_id(tenant_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Tenant not found")
|
||||
return get_error_permission_result(message="Tenant not found")
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
@ -153,16 +163,7 @@ def create(tenant_id):
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
|
||||
response_data = {}
|
||||
key_mapping = {
|
||||
"chunk_num": "chunk_count",
|
||||
"doc_num": "document_count",
|
||||
"parser_id": "chunk_method",
|
||||
"embd_id": "embedding_model",
|
||||
}
|
||||
for key, value in k.to_dict().items():
|
||||
new_key = key_mapping.get(key, key)
|
||||
response_data[new_key] = value
|
||||
response_data = remap_dictionary_keys(k.to_dict())
|
||||
return get_result(data=response_data)
|
||||
|
||||
|
||||
@ -232,7 +233,7 @@ def delete(tenant_id):
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
if len(error_kb_ids) > 0:
|
||||
return get_error_data_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
|
||||
return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
|
||||
|
||||
errors = []
|
||||
success_count = 0
|
||||
@ -347,7 +348,7 @@ def update(tenant_id, dataset_id):
|
||||
try:
|
||||
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
|
||||
if kb is None:
|
||||
return get_error_data_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
|
||||
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
@ -418,7 +419,7 @@ def list_datasets(tenant_id):
|
||||
name: page_size
|
||||
type: integer
|
||||
required: false
|
||||
default: 1024
|
||||
default: 30
|
||||
description: Number of items per page.
|
||||
- in: query
|
||||
name: orderby
|
||||
@ -445,47 +446,46 @@ def list_datasets(tenant_id):
|
||||
items:
|
||||
type: object
|
||||
"""
|
||||
id = request.args.get("id")
|
||||
name = request.args.get("name")
|
||||
if id:
|
||||
kbs = KnowledgebaseService.get_kb_by_id(id, tenant_id)
|
||||
args, err = validate_and_parse_request_args(request, ListDatasetReq)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
kb_id = request.args.get("id")
|
||||
name = args.get("name")
|
||||
if kb_id:
|
||||
try:
|
||||
kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id)
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
if not kbs:
|
||||
return get_error_data_result(f"You don't own the dataset {id}")
|
||||
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'")
|
||||
if name:
|
||||
kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id)
|
||||
try:
|
||||
kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id)
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
if not kbs:
|
||||
return get_error_data_result(f"You don't own the dataset {name}")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 30))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "false").lower() not in ["true", "false"]:
|
||||
return get_error_data_result("desc should be true or false")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
|
||||
kbs = KnowledgebaseService.get_list(
|
||||
[m["tenant_id"] for m in tenants],
|
||||
tenant_id,
|
||||
page_number,
|
||||
items_per_page,
|
||||
orderby,
|
||||
desc,
|
||||
id,
|
||||
name,
|
||||
)
|
||||
renamed_list = []
|
||||
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'")
|
||||
|
||||
try:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
|
||||
kbs = KnowledgebaseService.get_list(
|
||||
[m["tenant_id"] for m in tenants],
|
||||
tenant_id,
|
||||
args["page"],
|
||||
args["page_size"],
|
||||
args["orderby"],
|
||||
args["desc"],
|
||||
kb_id,
|
||||
name,
|
||||
)
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
|
||||
response_data_list = []
|
||||
for kb in kbs:
|
||||
key_mapping = {
|
||||
"chunk_num": "chunk_count",
|
||||
"doc_num": "document_count",
|
||||
"parser_id": "chunk_method",
|
||||
"embd_id": "embedding_model",
|
||||
}
|
||||
renamed_data = {}
|
||||
for key, value in kb.items():
|
||||
new_key = key_mapping.get(key, key)
|
||||
renamed_data[new_key] = value
|
||||
renamed_list.append(renamed_data)
|
||||
return get_result(data=renamed_list)
|
||||
response_data_list.append(remap_dictionary_keys(kb))
|
||||
return get_result(data=response_data_list)
|
||||
|
@ -329,6 +329,14 @@ def get_error_argument_result(message="Invalid arguments"):
|
||||
return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=message)
|
||||
|
||||
|
||||
def get_error_permission_result(message="Permission error"):
    """Build the API response used when the caller lacks permission.

    Args:
        message: Human-readable explanation placed in the response.

    Returns:
        Whatever ``get_result`` produces for ``RetCode.PERMISSION_ERROR``.
    """
    error_code = settings.RetCode.PERMISSION_ERROR
    return get_result(code=error_code, message=message)
|
||||
|
||||
|
||||
def get_error_operating_result(message="Operating error"):
    """Build the API response used when a requested operation cannot proceed.

    Args:
        message: Human-readable explanation placed in the response.

    Returns:
        Whatever ``get_result`` produces for ``RetCode.OPERATING_ERROR``.
    """
    error_code = settings.RetCode.OPERATING_ERROR
    return get_result(code=error_code, message=message)
|
||||
|
||||
|
||||
def generate_confirmation_token(tenant_id):
    """Create a ``ragflow-`` prefixed token derived from a signed fresh UUID.

    The tenant id is used both as the serializer's secret and as the salt.

    Args:
        tenant_id: Tenant identifier the token is generated for.

    Returns:
        ``"ragflow-"`` followed by characters 2..33 of the signed payload.
    """
    signer = URLSafeTimedSerializer(tenant_id)
    signed_payload = signer.dumps(get_uuid(), salt=tenant_id)
    # Only a fixed 32-character slice of the signed value is kept.
    return "ragflow-" + signed_payload[2:34]
|
||||
@ -514,3 +522,38 @@ def deep_merge(default: dict, custom: dict) -> dict:
|
||||
base_dict[key] = val
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
|
||||
"""
|
||||
Transform dictionary keys using a configurable mapping schema.
|
||||
|
||||
Args:
|
||||
source_data: Original dictionary to process
|
||||
key_aliases: Custom key transformation rules (Optional)
|
||||
When provided, overrides default key mapping
|
||||
Format: {<original_key>: <new_key>, ...}
|
||||
|
||||
Returns:
|
||||
dict: New dictionary with transformed keys preserving original values
|
||||
|
||||
Example:
|
||||
>>> input_data = {"old_key": "value", "another_field": 42}
|
||||
>>> remap_dictionary_keys(input_data, {"old_key": "new_key"})
|
||||
{'new_key': 'value', 'another_field': 42}
|
||||
"""
|
||||
DEFAULT_KEY_MAP = {
|
||||
"chunk_num": "chunk_count",
|
||||
"doc_num": "document_count",
|
||||
"parser_id": "chunk_method",
|
||||
"embd_id": "embedding_model",
|
||||
}
|
||||
|
||||
transformed_data = {}
|
||||
mapping = key_aliases or DEFAULT_KEY_MAP
|
||||
|
||||
for original_key, value in source_data.items():
|
||||
mapped_key = mapping.get(original_key, original_key)
|
||||
transformed_data[mapped_key] = value
|
||||
|
||||
return transformed_data
|
||||
|
@ -13,13 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from enum import auto
|
||||
from typing import Annotated, Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Request
|
||||
from pydantic import UUID1, BaseModel, Field, StringConstraints, ValidationError, field_serializer, field_validator
|
||||
from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator
|
||||
from pydantic_core import PydanticCustomError
|
||||
from strenum import StrEnum
|
||||
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
|
||||
@ -102,6 +102,71 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
|
||||
return parsed_payload, None
|
||||
|
||||
|
||||
def validate_and_parse_request_args(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None) -> tuple[dict[str, Any] | None, str | None]:
    """
    Validate a request's query parameters against a Pydantic model.

    The query string is flattened into a plain dict, optionally merged with
    ``extras``, and passed to ``validator``. On success the validated model is
    dumped back to a dict and every key that came from ``extras`` is removed
    again, so callers only receive the request's own parameters.

    Args:
        request (Request): Request object whose ``args`` hold the query parameters.
        validator (type[BaseModel]): Pydantic model class used for validation.
        extras (dict[str, Any] | None): Additional values that take part in
            validation but are stripped from the returned dict. Values in
            ``extras`` override query parameters with the same name.

    Returns:
        tuple[dict[str, Any] | None, str | None]:
            - ``(parsed_args, None)`` when validation succeeds.
            - ``(None, error_message)`` when a ``ValidationError`` is raised;
              the message comes from ``format_validation_error_message``.

    Examples:
        >>> validate_and_parse_request_args(request, MyValidator)
        ({'param1': 'value'}, None)

        >>> validate_and_parse_request_args(request, MyValidator, extras={'internal_id': 123})
        ({'param1': 'value'}, None)  # internal_id removed from output
    """
    injected = extras if extras is not None else {}

    # to_dict() yields a plain dict, so merging extras leaves request.args untouched.
    query_values = request.args.to_dict(flat=True)
    query_values.update(injected)

    try:
        model = validator(**query_values)
    except ValidationError as exc:
        return None, format_validation_error_message(exc)

    cleaned = {field: value for field, value in model.model_dump().items() if field not in injected}
    return cleaned, None
|
||||
|
||||
|
||||
def format_validation_error_message(e: ValidationError) -> str:
|
||||
"""
|
||||
Formats validation errors into a standardized string format.
|
||||
@ -143,6 +208,105 @@ def format_validation_error_message(e: ValidationError) -> str:
|
||||
return "\n".join(error_messages)
|
||||
|
||||
|
||||
def normalize_str(v: Any) -> Any:
    """
    Strip and lowercase string input; return anything else untouched.

    Args:
        v (Any): Value to normalize.

    Returns:
        Any: ``v.strip().lower()`` when ``v`` is a ``str``, otherwise ``v``
        itself (so ``None``, numbers, lists, ... pass straight through).

    Example:
        >>> normalize_str(" ReadOnly ")
        'readonly'
        >>> normalize_str("   ")
        ''
        >>> normalize_str(42)
        42
    """
    if not isinstance(v, str):
        return v
    return v.strip().lower()
|
||||
|
||||
|
||||
def validate_uuid1_hex(v: Any) -> str:
    """
    Validate that the input is a version 1 UUID and return its hex form.

    Strings are parsed with ``uuid.UUID``; any other value is used as-is and
    is expected to expose ``version`` and ``hex`` attributes like a ``UUID``.

    Args:
        v (Any): A UUID-formatted string (with or without hyphens) or a
            ``UUID`` object.

    Returns:
        str: The UUID's 32-character hexadecimal representation without hyphens.

    Raises:
        PydanticCustomError: Type ``"invalid_UUID1_format"`` with the message
            ``"Invalid UUID1 format"`` when the value cannot be interpreted
            as a UUID or when its version is not 1. Both failure modes
            deliberately report the same message.

    Examples:
        Valid cases:
        >>> validate_uuid1_hex("d94a8dc0-2c97-11f0-930f-7fbc369eab6d")
        'd94a8dc02c9711f0930f7fbc369eab6d'
        >>> validate_uuid1_hex(UUID("d94a8dc0-2c97-11f0-930f-7fbc369eab6d"))
        'd94a8dc02c9711f0930f7fbc369eab6d'

        Invalid cases (each raises PydanticCustomError):
        >>> validate_uuid1_hex("not-a-uuid")
        >>> validate_uuid1_hex(12345)
        >>> validate_uuid1_hex("550e8400-e29b-41d4-a716-446655440000")  # version 4
    """
    try:
        uuid_obj = UUID(v) if isinstance(v, str) else v
        version = uuid_obj.version
        hex_value = uuid_obj.hex
    except (AttributeError, ValueError, TypeError):
        # Unparseable strings and non-UUID objects end up here.
        raise PydanticCustomError("invalid_UUID1_format", "Invalid UUID1 format")

    # The version check lives outside the try block so its error is raised
    # directly instead of being caught and re-wrapped by the handler above.
    if version != 1:
        raise PydanticCustomError("invalid_UUID1_format", "Invalid UUID1 format")

    return hex_value
|
||||
|
||||
|
||||
class PermissionEnum(StrEnum):
|
||||
me = auto()
|
||||
team = auto()
|
||||
@ -217,8 +381,8 @@ class CreateDatasetReq(Base):
|
||||
avatar: str | None = Field(default=None, max_length=65535)
|
||||
description: str | None = Field(default=None, max_length=65535)
|
||||
embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
|
||||
permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)]
|
||||
chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")]
|
||||
permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
|
||||
chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
|
||||
pagerank: int = Field(default=0, ge=0, le=100)
|
||||
parser_config: ParserConfig | None = Field(default=None)
|
||||
|
||||
@ -315,22 +479,8 @@ class CreateDatasetReq(Base):
|
||||
|
||||
@field_validator("permission", mode="before")
|
||||
@classmethod
|
||||
def permission_auto_lowercase(cls, v: Any) -> Any:
|
||||
"""
|
||||
Normalize permission input to lowercase for consistent PermissionEnum matching.
|
||||
|
||||
Args:
|
||||
v (Any): Raw input value for the permission field
|
||||
|
||||
Returns:
|
||||
Lowercase string if input is string type, otherwise returns original value
|
||||
|
||||
Behavior:
|
||||
- Converts string inputs to lowercase (e.g., "ME" → "me")
|
||||
- Non-string values pass through unchanged
|
||||
- Works in validation pre-processing stage (before enum conversion)
|
||||
"""
|
||||
return v.lower() if isinstance(v, str) else v
|
||||
def normalize_permission(cls, v: Any) -> Any:
|
||||
return normalize_str(v)
|
||||
|
||||
@field_validator("parser_config", mode="before")
|
||||
@classmethod
|
||||
@ -387,93 +537,117 @@ class CreateDatasetReq(Base):
|
||||
|
||||
|
||||
class UpdateDatasetReq(CreateDatasetReq):
|
||||
dataset_id: UUID1 = Field(...)
|
||||
dataset_id: str = Field(...)
|
||||
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
|
||||
|
||||
@field_serializer("dataset_id")
|
||||
def serialize_uuid_to_hex(self, v: uuid.UUID) -> str:
|
||||
"""
|
||||
Serializes a UUID version 1 object to its hexadecimal string representation.
|
||||
|
||||
This field serializer specifically handles UUID version 1 objects, converting them
|
||||
to their canonical 32-character hexadecimal format without hyphens. The conversion
|
||||
is designed for consistent serialization in API responses and database storage.
|
||||
|
||||
Args:
|
||||
v (uuid.UUID1): The UUID version 1 object to serialize. Must be a valid
|
||||
UUID1 instance generated by Python's uuid module.
|
||||
|
||||
Returns:
|
||||
str: 32-character lowercase hexadecimal string representation
|
||||
Example: "550e8400e29b41d4a716446655440000"
|
||||
|
||||
Raises:
|
||||
AttributeError: If input is not a proper UUID object (missing hex attribute)
|
||||
TypeError: If input is not a UUID1 instance (when type checking is enabled)
|
||||
|
||||
Notes:
|
||||
- Version 1 UUIDs contain timestamp and MAC address information
|
||||
- The .hex property automatically converts to lowercase hexadecimal
|
||||
- For cross-version compatibility, consider typing as uuid.UUID instead
|
||||
"""
|
||||
return v.hex
|
||||
@field_validator("dataset_id", mode="before")
|
||||
@classmethod
|
||||
def validate_dataset_id(cls, v: Any) -> str:
|
||||
return validate_uuid1_hex(v)
|
||||
|
||||
|
||||
class DeleteReq(Base):
|
||||
ids: list[UUID1] | None = Field(...)
|
||||
ids: list[str] | None = Field(...)
|
||||
|
||||
@field_validator("ids", mode="after")
|
||||
def check_duplicate_ids(cls, v: list[UUID1] | None) -> list[str] | None:
|
||||
@classmethod
|
||||
def validate_ids(cls, v_list: list[str] | None) -> list[str] | None:
|
||||
"""
|
||||
Validates and converts a list of UUID1 objects to hexadecimal strings while checking for duplicates.
|
||||
Validates and normalizes a list of UUID strings with None handling.
|
||||
|
||||
This validator implements a three-stage processing pipeline:
|
||||
1. Null Handling - returns None for empty/null input
|
||||
2. UUID Conversion - transforms UUID objects to hex strings
|
||||
3. Duplicate Validation - ensures all IDs are unique
|
||||
|
||||
Behavior Specifications:
|
||||
- Input: None → Returns None (indicates no operation)
|
||||
- Input: [] → Returns [] (empty list for explicit no-op)
|
||||
- Input: [UUID1,...] → Returns validated hex strings
|
||||
- Duplicates: Raises formatted PydanticCustomError
|
||||
This post-processing validator performs:
|
||||
1. None input handling (pass-through)
|
||||
2. UUID version 1 validation for each list item
|
||||
3. Duplicate value detection
|
||||
4. Returns normalized UUID hex strings or None
|
||||
|
||||
Args:
|
||||
v (list[UUID1] | None):
|
||||
- None: Indicates no datasets should be processed
|
||||
- Empty list: Explicit empty operation
|
||||
- Populated list: Dataset UUIDs to validate/convert
|
||||
v_list (list[str] | None): Input list that has passed initial validation.
|
||||
Either a list of UUID strings or None.
|
||||
|
||||
Returns:
|
||||
list[str] | None:
|
||||
- None when input is None
|
||||
- List of 32-character hex strings (lowercase, no hyphens)
|
||||
Example: ["550e8400e29b41d4a716446655440000"]
|
||||
- None if input was None
|
||||
- List of normalized UUID hex strings otherwise:
|
||||
* 32-character lowercase
|
||||
* Valid UUID version 1
|
||||
* Unique within list
|
||||
|
||||
Raises:
|
||||
PydanticCustomError: When duplicates detected, containing:
|
||||
- Error type: "duplicate_uuids"
|
||||
- Template message: "Duplicate ids: '{duplicate_ids}'"
|
||||
- Context: {"duplicate_ids": "id1, id2, ..."}
|
||||
PydanticCustomError: With structured error details when:
|
||||
- "invalid_UUID1_format": Any string fails UUIDv1 validation
|
||||
- "duplicate_uuids": If duplicate IDs are detected
|
||||
|
||||
Example:
|
||||
>>> validate([UUID("..."), UUID("...")])
|
||||
["2cdf0456e9a711ee8000000000000000", ...]
|
||||
Validation Rules:
|
||||
- None input returns None
|
||||
- Empty list returns empty list
|
||||
- All non-None items must be valid UUIDv1
|
||||
- No duplicates permitted
|
||||
- Original order preserved
|
||||
|
||||
>>> validate([UUID("..."), UUID("...")]) # Duplicates
|
||||
PydanticCustomError: Duplicate ids: '2cdf0456e9a711ee8000000000000000'
|
||||
Examples:
|
||||
Valid cases:
|
||||
>>> validate_ids(None)
|
||||
None
|
||||
>>> validate_ids([])
|
||||
[]
|
||||
>>> validate_ids(["550e8400-e29b-41d4-a716-446655440000"])
|
||||
["550e8400e29b41d4a716446655440000"]
|
||||
|
||||
Invalid cases:
|
||||
>>> validate_ids(["invalid"])
|
||||
# raises PydanticCustomError(invalid_UUID1_format)
|
||||
>>> validate_ids(["550e...", "550e..."])
|
||||
# raises PydanticCustomError(duplicate_uuids)
|
||||
|
||||
Security Notes:
|
||||
- Validates UUID version to prevent version spoofing
|
||||
- Duplicate check prevents data injection
|
||||
- None handling maintains pipeline integrity
|
||||
"""
|
||||
if not v:
|
||||
return v
|
||||
if v_list is None:
|
||||
return None
|
||||
|
||||
uuid_hex_list = [ids.hex for ids in v]
|
||||
duplicates = [item for item, count in Counter(uuid_hex_list).items() if count > 1]
|
||||
ids_list = []
|
||||
for v in v_list:
|
||||
try:
|
||||
ids_list.append(validate_uuid1_hex(v))
|
||||
except PydanticCustomError as e:
|
||||
raise e
|
||||
|
||||
duplicates = [item for item, count in Counter(ids_list).items() if count > 1]
|
||||
if duplicates:
|
||||
duplicates_str = ", ".join(duplicates)
|
||||
raise PydanticCustomError("duplicate_uuids", "Duplicate ids: '{duplicate_ids}'", {"duplicate_ids": duplicates_str})
|
||||
|
||||
return uuid_hex_list
|
||||
return ids_list
|
||||
|
||||
|
||||
class DeleteDatasetReq(DeleteReq):
    # Dataset-specific alias of DeleteReq; adds no fields or validators.
    pass
|
||||
|
||||
|
||||
class OrderByEnum(StrEnum):
    # Allowed sort columns for list requests; each value equals its member name.
    create_time = "create_time"
    update_time = "update_time"
|
||||
|
||||
|
||||
class BaseListReq(Base):
    # Shared schema for list-style query parameters (subclassed by ListDatasetReq).
    # id / name: optional filters; page / page_size: 1-based pagination, both >= 1;
    # orderby: sort column; desc: sort descending when True.
    id: str | None = None
    name: str | None = None
    page: int = Field(default=1, ge=1)
    page_size: int = Field(default=30, ge=1)
    orderby: OrderByEnum = Field(default=OrderByEnum.create_time)
    desc: bool = Field(default=True)

    @field_validator("id", mode="before")
    @classmethod
    def validate_id(cls, v: Any) -> str | None:
        # The field is declared optional, so an explicit None must pass through
        # instead of being rejected as a malformed UUID.
        if v is None:
            return None
        # Any supplied value must be a version 1 UUID; it is normalized to hex.
        return validate_uuid1_hex(v)

    @field_validator("orderby", mode="before")
    @classmethod
    def normalize_orderby(cls, v: Any) -> Any:
        # Trim and lowercase before enum matching so " Create_Time " is accepted.
        return normalize_str(v)
|
||||
|
||||
|
||||
class ListDatasetReq(BaseListReq):
    # Query-parameter schema for listing datasets; inherits everything from BaseListReq.
    pass
|
||||
|
@ -122,7 +122,7 @@ class TestDatasetCreate:
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 102, res
|
||||
assert res["code"] == 103, res
|
||||
assert res["message"] == f"Dataset name '{name}' already exists", res
|
||||
|
||||
@pytest.mark.p3
|
||||
@ -134,7 +134,7 @@ class TestDatasetCreate:
|
||||
|
||||
payload = {"name": name.lower()}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 102, res
|
||||
assert res["code"] == 103, res
|
||||
assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
|
||||
|
||||
@pytest.mark.p2
|
||||
@ -296,14 +296,15 @@ class TestDatasetCreate:
|
||||
("team", "team"),
|
||||
("me_upercase", "ME"),
|
||||
("team_upercase", "TEAM"),
|
||||
("whitespace", " ME "),
|
||||
],
|
||||
ids=["me", "team", "me_upercase", "team_upercase"],
|
||||
ids=["me", "team", "me_upercase", "team_upercase", "whitespace"],
|
||||
)
|
||||
def test_permission(self, get_http_api_auth, name, permission):
|
||||
payload = {"name": name, "permission": permission}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["permission"] == permission.lower(), res
|
||||
assert res["data"]["permission"] == permission.lower().strip(), res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
@ -40,8 +41,8 @@ class TestAuthorization:
|
||||
)
|
||||
def test_auth_invalid(self, auth, expected_code, expected_message):
|
||||
res = delete_datasets(auth)
|
||||
assert res["code"] == expected_code
|
||||
assert res["message"] == expected_message
|
||||
assert res["code"] == expected_code, res
|
||||
assert res["message"] == expected_message, res
|
||||
|
||||
|
||||
class TestRquest:
|
||||
@ -140,17 +141,25 @@ class TestDatasetsDelete:
|
||||
payload = {"ids": ["not_uuid"]}
|
||||
res = delete_datasets(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be a valid UUID" in res["message"], res
|
||||
assert "Invalid UUID1 format" in res["message"], res
|
||||
|
||||
res = list_datasets(get_http_api_auth)
|
||||
assert len(res["data"]) == 1, res
|
||||
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.usefixtures("add_dataset_func")
|
||||
def test_id_not_uuid1(self, get_http_api_auth):
|
||||
payload = {"ids": [uuid.uuid4().hex]}
|
||||
res = delete_datasets(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Invalid UUID1 format" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.usefixtures("add_dataset_func")
|
||||
def test_id_wrong_uuid(self, get_http_api_auth):
|
||||
payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]}
|
||||
res = delete_datasets(get_http_api_auth, payload)
|
||||
assert res["code"] == 102, res
|
||||
assert res["code"] == 108, res
|
||||
assert "lacks permission for dataset" in res["message"], res
|
||||
|
||||
res = list_datasets(get_http_api_auth)
|
||||
@ -170,7 +179,7 @@ class TestDatasetsDelete:
|
||||
if callable(func):
|
||||
payload = func(dataset_ids)
|
||||
res = delete_datasets(get_http_api_auth, payload)
|
||||
assert res["code"] == 102, res
|
||||
assert res["code"] == 108, res
|
||||
assert "lacks permission for dataset" in res["message"], res
|
||||
|
||||
res = list_datasets(get_http_api_auth)
|
||||
@ -195,7 +204,7 @@ class TestDatasetsDelete:
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = delete_datasets(get_http_api_auth, payload)
|
||||
assert res["code"] == 102, res
|
||||
assert res["code"] == 108, res
|
||||
assert "lacks permission for dataset" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
|
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
@ -21,8 +22,8 @@ from libs.auth import RAGFlowHttpApiAuth
|
||||
from libs.utils import is_sorted
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"auth, expected_code, expected_message",
|
||||
[
|
||||
@ -34,269 +35,305 @@ class TestAuthorization:
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invalid_auth(self, auth, expected_code, expected_message):
|
||||
def test_auth_invalid(self, auth, expected_code, expected_message):
|
||||
res = list_datasets(auth)
|
||||
assert res["code"] == expected_code
|
||||
assert res["message"] == expected_message
|
||||
assert res["code"] == expected_code, res
|
||||
assert res["message"] == expected_message, res
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_datasets")
|
||||
class TestDatasetsList:
|
||||
@pytest.mark.p1
|
||||
def test_default(self, get_http_api_auth):
|
||||
res = list_datasets(get_http_api_auth, params={})
|
||||
|
||||
assert res["code"] == 0
|
||||
assert len(res["data"]) == 5
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"params, expected_code, expected_page_size, expected_message",
|
||||
[
|
||||
({"page": None, "page_size": 2}, 0, 2, ""),
|
||||
({"page": 0, "page_size": 2}, 0, 2, ""),
|
||||
({"page": 2, "page_size": 2}, 0, 2, ""),
|
||||
({"page": 3, "page_size": 2}, 0, 1, ""),
|
||||
({"page": "3", "page_size": 2}, 0, 1, ""),
|
||||
pytest.param(
|
||||
{"page": -1, "page_size": 2},
|
||||
100,
|
||||
0,
|
||||
"1064",
|
||||
marks=pytest.mark.skip(reason="issues/5851"),
|
||||
),
|
||||
pytest.param(
|
||||
{"page": "a", "page_size": 2},
|
||||
100,
|
||||
0,
|
||||
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
|
||||
marks=pytest.mark.skip(reason="issues/5851"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_page(self, get_http_api_auth, params, expected_code, expected_page_size, expected_message):
|
||||
res = list_datasets(get_http_api_auth, params=params)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
assert len(res["data"]) == expected_page_size
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"params, expected_code, expected_page_size, expected_message",
|
||||
[
|
||||
({"page_size": None}, 0, 5, ""),
|
||||
({"page_size": 0}, 0, 0, ""),
|
||||
({"page_size": 1}, 0, 1, ""),
|
||||
({"page_size": 6}, 0, 5, ""),
|
||||
({"page_size": "1"}, 0, 1, ""),
|
||||
pytest.param(
|
||||
{"page_size": -1},
|
||||
100,
|
||||
0,
|
||||
"1064",
|
||||
marks=pytest.mark.skip(reason="issues/5851"),
|
||||
),
|
||||
pytest.param(
|
||||
{"page_size": "a"},
|
||||
100,
|
||||
0,
|
||||
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
|
||||
marks=pytest.mark.skip(reason="issues/5851"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_page_size(
|
||||
self,
|
||||
get_http_api_auth,
|
||||
params,
|
||||
expected_code,
|
||||
expected_page_size,
|
||||
expected_message,
|
||||
):
|
||||
res = list_datasets(get_http_api_auth, params=params)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
assert len(res["data"]) == expected_page_size
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.parametrize(
|
||||
"params, expected_code, assertions, expected_message",
|
||||
[
|
||||
({"orderby": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
|
||||
({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
|
||||
({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", True)), ""),
|
||||
pytest.param(
|
||||
{"orderby": "name", "desc": "False"},
|
||||
0,
|
||||
lambda r: (is_sorted(r["data"]["docs"], "name", False)),
|
||||
"",
|
||||
marks=pytest.mark.skip(reason="issues/5851"),
|
||||
),
|
||||
pytest.param(
|
||||
{"orderby": "unknown"},
|
||||
102,
|
||||
0,
|
||||
"orderby should be create_time or update_time",
|
||||
marks=pytest.mark.skip(reason="issues/5851"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_orderby(
|
||||
self,
|
||||
get_http_api_auth,
|
||||
params,
|
||||
expected_code,
|
||||
assertions,
|
||||
expected_message,
|
||||
):
|
||||
res = list_datasets(get_http_api_auth, params=params)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
if callable(assertions):
|
||||
assert assertions(res)
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.parametrize(
|
||||
"params, expected_code, assertions, expected_message",
|
||||
[
|
||||
({"desc": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
|
||||
({"desc": "true"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
|
||||
({"desc": "True"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
|
||||
({"desc": True}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
|
||||
({"desc": "false"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
|
||||
({"desc": "False"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
|
||||
({"desc": False}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
|
||||
({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", False)), ""),
|
||||
pytest.param(
|
||||
{"desc": "unknown"},
|
||||
102,
|
||||
0,
|
||||
"desc should be true or false",
|
||||
marks=pytest.mark.skip(reason="issues/5851"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_desc(
|
||||
self,
|
||||
get_http_api_auth,
|
||||
params,
|
||||
expected_code,
|
||||
assertions,
|
||||
expected_message,
|
||||
):
|
||||
res = list_datasets(get_http_api_auth, params=params)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
if callable(assertions):
|
||||
assert assertions(res)
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"params, expected_code, expected_num, expected_message",
|
||||
[
|
||||
({"name": None}, 0, 5, ""),
|
||||
({"name": ""}, 0, 5, ""),
|
||||
({"name": "dataset_1"}, 0, 1, ""),
|
||||
({"name": "unknown"}, 102, 0, "You don't own the dataset unknown"),
|
||||
],
|
||||
)
|
||||
def test_name(self, get_http_api_auth, params, expected_code, expected_num, expected_message):
|
||||
res = list_datasets(get_http_api_auth, params=params)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
if params["name"] in [None, ""]:
|
||||
assert len(res["data"]) == expected_num
|
||||
else:
|
||||
assert res["data"][0]["name"] == params["name"]
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"dataset_id, expected_code, expected_num, expected_message",
|
||||
[
|
||||
(None, 0, 5, ""),
|
||||
("", 0, 5, ""),
|
||||
(lambda r: r[0], 0, 1, ""),
|
||||
("unknown", 102, 0, "You don't own the dataset unknown"),
|
||||
],
|
||||
)
|
||||
def test_id(
|
||||
self,
|
||||
get_http_api_auth,
|
||||
add_datasets,
|
||||
dataset_id,
|
||||
expected_code,
|
||||
expected_num,
|
||||
expected_message,
|
||||
):
|
||||
dataset_ids = add_datasets
|
||||
if callable(dataset_id):
|
||||
params = {"id": dataset_id(dataset_ids)}
|
||||
else:
|
||||
params = {"id": dataset_id}
|
||||
|
||||
res = list_datasets(get_http_api_auth, params=params)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
if params["id"] in [None, ""]:
|
||||
assert len(res["data"]) == expected_num
|
||||
else:
|
||||
assert res["data"][0]["id"] == params["id"]
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.parametrize(
|
||||
"dataset_id, name, expected_code, expected_num, expected_message",
|
||||
[
|
||||
(lambda r: r[0], "dataset_0", 0, 1, ""),
|
||||
(lambda r: r[0], "dataset_1", 0, 0, ""),
|
||||
(lambda r: r[0], "unknown", 102, 0, "You don't own the dataset unknown"),
|
||||
("id", "dataset_0", 102, 0, "You don't own the dataset id"),
|
||||
],
|
||||
)
|
||||
def test_name_and_id(
|
||||
self,
|
||||
get_http_api_auth,
|
||||
add_datasets,
|
||||
dataset_id,
|
||||
name,
|
||||
expected_code,
|
||||
expected_num,
|
||||
expected_message,
|
||||
):
|
||||
dataset_ids = add_datasets
|
||||
if callable(dataset_id):
|
||||
params = {"id": dataset_id(dataset_ids), "name": name}
|
||||
else:
|
||||
params = {"id": dataset_id, "name": name}
|
||||
|
||||
res = list_datasets(get_http_api_auth, params=params)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
assert len(res["data"]) == expected_num
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
|
||||
class TestCapability:
|
||||
@pytest.mark.p3
|
||||
def test_concurrent_list(self, get_http_api_auth):
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(list_datasets, get_http_api_auth) for i in range(100)]
|
||||
responses = [f.result() for f in futures]
|
||||
assert all(r["code"] == 0 for r in responses)
|
||||
assert all(r["code"] == 0 for r in responses), responses
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_datasets")
|
||||
class TestDatasetsList:
|
||||
@pytest.mark.p1
|
||||
def test_params_unset(self, get_http_api_auth):
|
||||
res = list_datasets(get_http_api_auth, None)
|
||||
assert res["code"] == 0, res
|
||||
assert len(res["data"]) == 5, res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_params_empty(self, get_http_api_auth):
|
||||
res = list_datasets(get_http_api_auth, {})
|
||||
assert res["code"] == 0, res
|
||||
assert len(res["data"]) == 5, res
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"params, expected_page_size",
|
||||
[
|
||||
({"page": 2, "page_size": 2}, 2),
|
||||
({"page": 3, "page_size": 2}, 1),
|
||||
({"page": 4, "page_size": 2}, 0),
|
||||
({"page": "2", "page_size": 2}, 2),
|
||||
({"page": 1, "page_size": 10}, 5),
|
||||
],
|
||||
ids=["normal_middle_page", "normal_last_partial_page", "beyond_max_page", "string_page_number", "full_data_single_page"],
|
||||
)
|
||||
def test_page(self, get_http_api_auth, params, expected_page_size):
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0, res
|
||||
assert len(res["data"]) == expected_page_size, res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"params, expected_code, expected_message",
|
||||
[
|
||||
({"page": 0}, 101, "Input should be greater than or equal to 1"),
|
||||
({"page": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"),
|
||||
],
|
||||
ids=["page_0", "page_a"],
|
||||
)
|
||||
def test_page_invalid(self, get_http_api_auth, params, expected_code, expected_message):
|
||||
res = list_datasets(get_http_api_auth, params=params)
|
||||
assert res["code"] == expected_code, res
|
||||
assert expected_message in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_page_none(self, get_http_api_auth):
|
||||
params = {"page": None}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0, res
|
||||
assert len(res["data"]) == 5, res
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"params, expected_page_size",
|
||||
[
|
||||
({"page_size": 1}, 1),
|
||||
({"page_size": 3}, 3),
|
||||
({"page_size": 5}, 5),
|
||||
({"page_size": 6}, 5),
|
||||
({"page_size": "1"}, 1),
|
||||
],
|
||||
ids=["min_valid_page_size", "medium_page_size", "page_size_equals_total", "page_size_exceeds_total", "string_type_page_size"],
|
||||
)
|
||||
def test_page_size(self, get_http_api_auth, params, expected_page_size):
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0, res
|
||||
assert len(res["data"]) == expected_page_size, res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"params, expected_code, expected_message",
|
||||
[
|
||||
({"page_size": 0}, 101, "Input should be greater than or equal to 1"),
|
||||
({"page_size": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"),
|
||||
],
|
||||
)
|
||||
def test_page_size_invalid(self, get_http_api_auth, params, expected_code, expected_message):
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == expected_code, res
|
||||
assert expected_message in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_page_size_none(self, get_http_api_auth):
|
||||
params = {"page_size": None}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0, res
|
||||
assert len(res["data"]) == 5, res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"params, assertions",
|
||||
[
|
||||
({"orderby": "create_time"}, lambda r: (is_sorted(r["data"], "create_time", True))),
|
||||
({"orderby": "update_time"}, lambda r: (is_sorted(r["data"], "update_time", True))),
|
||||
({"orderby": "CREATE_TIME"}, lambda r: (is_sorted(r["data"], "create_time", True))),
|
||||
({"orderby": "UPDATE_TIME"}, lambda r: (is_sorted(r["data"], "update_time", True))),
|
||||
({"orderby": " create_time "}, lambda r: (is_sorted(r["data"], "update_time", True))),
|
||||
],
|
||||
ids=["orderby_create_time", "orderby_update_time", "orderby_create_time_upper", "orderby_update_time_upper", "whitespace"],
|
||||
)
|
||||
def test_orderby(self, get_http_api_auth, params, assertions):
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0, res
|
||||
if callable(assertions):
|
||||
assert assertions(res), res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_invalid_params(self, get_http_api_auth):
|
||||
params = {"a": "b"}
|
||||
res = list_datasets(get_http_api_auth, params=params)
|
||||
@pytest.mark.parametrize(
|
||||
"params",
|
||||
[
|
||||
{"orderby": ""},
|
||||
{"orderby": "unknown"},
|
||||
],
|
||||
ids=["empty", "unknown"],
|
||||
)
|
||||
def test_orderby_invalid(self, get_http_api_auth, params):
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be 'create_time' or 'update_time'" in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_orderby_none(self, get_http_api_auth):
|
||||
params = {"order_by": None}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0, res
|
||||
assert is_sorted(res["data"], "create_time", True), res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"params, assertions",
|
||||
[
|
||||
({"desc": True}, lambda r: (is_sorted(r["data"], "create_time", True))),
|
||||
({"desc": False}, lambda r: (is_sorted(r["data"], "create_time", False))),
|
||||
({"desc": "true"}, lambda r: (is_sorted(r["data"], "create_time", True))),
|
||||
({"desc": "false"}, lambda r: (is_sorted(r["data"], "create_time", False))),
|
||||
({"desc": 1}, lambda r: (is_sorted(r["data"], "create_time", True))),
|
||||
({"desc": 0}, lambda r: (is_sorted(r["data"], "create_time", False))),
|
||||
({"desc": "yes"}, lambda r: (is_sorted(r["data"], "create_time", True))),
|
||||
({"desc": "no"}, lambda r: (is_sorted(r["data"], "create_time", False))),
|
||||
({"desc": "y"}, lambda r: (is_sorted(r["data"], "create_time", True))),
|
||||
({"desc": "n"}, lambda r: (is_sorted(r["data"], "create_time", False))),
|
||||
],
|
||||
ids=["desc=True", "desc=False", "desc=true", "desc=false", "desc=1", "desc=0", "desc=yes", "desc=no", "desc=y", "desc=n"],
|
||||
)
|
||||
def test_desc(self, get_http_api_auth, params, assertions):
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0, res
|
||||
if callable(assertions):
|
||||
assert assertions(res), res
|
||||
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.parametrize(
|
||||
"params",
|
||||
[
|
||||
{"desc": 3.14},
|
||||
{"desc": "unknown"},
|
||||
],
|
||||
ids=["empty", "unknown"],
|
||||
)
|
||||
def test_desc_invalid(self, get_http_api_auth, params):
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be a valid boolean, unable to interpret input" in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_desc_none(self, get_http_api_auth):
|
||||
params = {"desc": None}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0, res
|
||||
assert is_sorted(res["data"], "create_time", True), res
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_name(self, get_http_api_auth):
|
||||
params = {"name": "dataset_1"}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0, res
|
||||
assert len(res["data"]) == 1, res
|
||||
assert res["data"][0]["name"] == "dataset_1", res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_name_wrong(self, get_http_api_auth):
|
||||
params = {"name": "wrong name"}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 108, res
|
||||
assert "lacks permission for dataset" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_name_empty(self, get_http_api_auth):
|
||||
params = {"name": ""}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0, res
|
||||
assert len(res["data"]) == 5, res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_name_none(self, get_http_api_auth):
|
||||
params = {"name": None}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0, res
|
||||
assert len(res["data"]) == 5, res
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_id(self, get_http_api_auth, add_datasets):
|
||||
dataset_ids = add_datasets
|
||||
params = {"id": dataset_ids[0]}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0
|
||||
assert len(res["data"]) == 5
|
||||
assert len(res["data"]) == 1
|
||||
assert res["data"][0]["id"] == dataset_ids[0]
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_id_not_uuid(self, get_http_api_auth):
|
||||
params = {"id": "not_uuid"}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 101, res
|
||||
assert "Invalid UUID1 format" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_id_not_uuid1(self, get_http_api_auth):
|
||||
params = {"id": uuid.uuid4().hex}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 101, res
|
||||
assert "Invalid UUID1 format" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_id_wrong_uuid(self, get_http_api_auth):
|
||||
params = {"id": "d94a8dc02c9711f0930f7fbc369eab6d"}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 108, res
|
||||
assert "lacks permission for dataset" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_id_empty(self, get_http_api_auth):
|
||||
params = {"id": ""}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 101, res
|
||||
assert "Invalid UUID1 format" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_id_none(self, get_http_api_auth):
|
||||
params = {"id": None}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0, res
|
||||
assert len(res["data"]) == 5, res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"func, name, expected_num",
|
||||
[
|
||||
(lambda r: r[0], "dataset_0", 1),
|
||||
(lambda r: r[0], "dataset_1", 0),
|
||||
],
|
||||
ids=["name_and_id_match", "name_and_id_mismatch"],
|
||||
)
|
||||
def test_name_and_id(self, get_http_api_auth, add_datasets, func, name, expected_num):
|
||||
dataset_ids = add_datasets
|
||||
if callable(func):
|
||||
params = {"id": func(dataset_ids), "name": name}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 0, res
|
||||
assert len(res["data"]) == expected_num, res
|
||||
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.parametrize(
|
||||
"dataset_id, name",
|
||||
[
|
||||
(lambda r: r[0], "wrong_name"),
|
||||
(uuid.uuid1().hex, "dataset_0"),
|
||||
],
|
||||
ids=["name", "id"],
|
||||
)
|
||||
def test_name_and_id_wrong(self, get_http_api_auth, add_datasets, dataset_id, name):
|
||||
dataset_ids = add_datasets
|
||||
if callable(dataset_id):
|
||||
params = {"id": dataset_id(dataset_ids), "name": name}
|
||||
else:
|
||||
params = {"id": dataset_id, "name": name}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 108, res
|
||||
assert "lacks permission for dataset" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_field_unsupported(self, get_http_api_auth):
|
||||
params = {"unknown_field": "unknown_field"}
|
||||
res = list_datasets(get_http_api_auth, params)
|
||||
assert res["code"] == 101, res
|
||||
assert "Extra inputs are not permitted" in res["message"], res
|
||||
|
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
@ -98,16 +99,23 @@ class TestCapability:
|
||||
class TestDatasetUpdate:
|
||||
@pytest.mark.p3
|
||||
def test_dataset_id_not_uuid(self, get_http_api_auth):
|
||||
payload = {"name": "not_uuid"}
|
||||
payload = {"name": "not uuid"}
|
||||
res = update_dataset(get_http_api_auth, "not_uuid", payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be a valid UUID" in res["message"], res
|
||||
assert "Invalid UUID1 format" in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_dataset_id_not_uuid1(self, get_http_api_auth):
|
||||
payload = {"name": "not uuid1"}
|
||||
res = update_dataset(get_http_api_auth, uuid.uuid4().hex, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Invalid UUID1 format" in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_dataset_id_wrong_uuid(self, get_http_api_auth):
|
||||
payload = {"name": "wrong_uuid"}
|
||||
payload = {"name": "wrong uuid"}
|
||||
res = update_dataset(get_http_api_auth, "d94a8dc02c9711f0930f7fbc369eab6d", payload)
|
||||
assert res["code"] == 102, res
|
||||
assert res["code"] == 108, res
|
||||
assert "lacks permission for dataset" in res["message"], res
|
||||
|
||||
@pytest.mark.p1
|
||||
@ -322,8 +330,9 @@ class TestDatasetUpdate:
|
||||
"team",
|
||||
"ME",
|
||||
"TEAM",
|
||||
" ME ",
|
||||
],
|
||||
ids=["me", "team", "me_upercase", "team_upercase"],
|
||||
ids=["me", "team", "me_upercase", "team_upercase", "whitespace"],
|
||||
)
|
||||
def test_permission(self, get_http_api_auth, add_dataset_func, permission):
|
||||
dataset_id = add_dataset_func
|
||||
@ -333,7 +342,7 @@ class TestDatasetUpdate:
|
||||
|
||||
res = list_datasets(get_http_api_auth)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["permission"] == permission.lower(), res
|
||||
assert res["data"][0]["permission"] == permission.lower().strip(), res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
@ -734,7 +743,6 @@ class TestDatasetUpdate:
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = list_datasets(get_http_api_auth)
|
||||
print(res)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res
|
||||
|
||||
@ -757,7 +765,6 @@ class TestDatasetUpdate:
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = list_datasets(get_http_api_auth, {"id": dataset_id})
|
||||
print(res)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user