Refactor: HTTP API list datasets / test cases / docs (#7720)

### What problem does this PR solve?

This PR introduces Pydantic-based validation for the list datasets HTTP
API, improving code clarity and robustness. Key changes include:

Pydantic Validation
Error Handling
Test Updates
Documentation Updates

### Type of change

- [x] Documentation Update
- [x] Refactoring
This commit is contained in:
liu an 2025-05-20 09:58:26 +08:00 committed by GitHub
parent 6ed81d6774
commit fed1221302
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 683 additions and 412 deletions

View File

@ -32,12 +32,22 @@ from api.utils.api_utils import (
deep_merge,
get_error_argument_result,
get_error_data_result,
get_error_operating_result,
get_error_permission_result,
get_parser_config,
get_result,
remap_dictionary_keys,
token_required,
verify_embedding_availability,
)
from api.utils.validation_utils import CreateDatasetReq, DeleteDatasetReq, UpdateDatasetReq, validate_and_parse_json_request
from api.utils.validation_utils import (
CreateDatasetReq,
DeleteDatasetReq,
ListDatasetReq,
UpdateDatasetReq,
validate_and_parse_json_request,
validate_and_parse_request_args,
)
@manager.route("/datasets", methods=["POST"]) # noqa: F821
@ -113,7 +123,7 @@ def create(tenant_id):
try:
if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
return get_error_operating_result(message=f"Dataset name '{req['name']}' already exists")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
@ -126,7 +136,7 @@ def create(tenant_id):
try:
ok, t = TenantService.get_by_id(tenant_id)
if not ok:
return get_error_data_result(message="Tenant not found")
return get_error_permission_result(message="Tenant not found")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
@ -153,16 +163,7 @@ def create(tenant_id):
logging.exception(e)
return get_error_data_result(message="Database operation failed")
response_data = {}
key_mapping = {
"chunk_num": "chunk_count",
"doc_num": "document_count",
"parser_id": "chunk_method",
"embd_id": "embedding_model",
}
for key, value in k.to_dict().items():
new_key = key_mapping.get(key, key)
response_data[new_key] = value
response_data = remap_dictionary_keys(k.to_dict())
return get_result(data=response_data)
@ -232,7 +233,7 @@ def delete(tenant_id):
logging.exception(e)
return get_error_data_result(message="Database operation failed")
if len(error_kb_ids) > 0:
return get_error_data_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
errors = []
success_count = 0
@ -347,7 +348,7 @@ def update(tenant_id, dataset_id):
try:
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
if kb is None:
return get_error_data_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
@ -418,7 +419,7 @@ def list_datasets(tenant_id):
name: page_size
type: integer
required: false
default: 1024
default: 30
description: Number of items per page.
- in: query
name: orderby
@ -445,47 +446,46 @@ def list_datasets(tenant_id):
items:
type: object
"""
id = request.args.get("id")
name = request.args.get("name")
if id:
kbs = KnowledgebaseService.get_kb_by_id(id, tenant_id)
args, err = validate_and_parse_request_args(request, ListDatasetReq)
if err is not None:
return get_error_argument_result(err)
kb_id = request.args.get("id")
name = args.get("name")
if kb_id:
try:
kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id)
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
if not kbs:
return get_error_data_result(f"You don't own the dataset {id}")
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'")
if name:
kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id)
try:
kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id)
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
if not kbs:
return get_error_data_result(f"You don't own the dataset {name}")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 30))
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc", "false").lower() not in ["true", "false"]:
return get_error_data_result("desc should be true or false")
if request.args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
kbs = KnowledgebaseService.get_list(
[m["tenant_id"] for m in tenants],
tenant_id,
page_number,
items_per_page,
orderby,
desc,
id,
name,
)
renamed_list = []
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'")
try:
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
kbs = KnowledgebaseService.get_list(
[m["tenant_id"] for m in tenants],
tenant_id,
args["page"],
args["page_size"],
args["orderby"],
args["desc"],
kb_id,
name,
)
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
response_data_list = []
for kb in kbs:
key_mapping = {
"chunk_num": "chunk_count",
"doc_num": "document_count",
"parser_id": "chunk_method",
"embd_id": "embedding_model",
}
renamed_data = {}
for key, value in kb.items():
new_key = key_mapping.get(key, key)
renamed_data[new_key] = value
renamed_list.append(renamed_data)
return get_result(data=renamed_list)
response_data_list.append(remap_dictionary_keys(kb))
return get_result(data=response_data_list)

View File

@ -329,6 +329,14 @@ def get_error_argument_result(message="Invalid arguments"):
return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=message)
def get_error_permission_result(message="Permission error"):
    """Build a standard error response carrying the PERMISSION_ERROR code."""
    return get_result(code=settings.RetCode.PERMISSION_ERROR, message=message)
def get_error_operating_result(message="Operating error"):
    """Build a standard error response carrying the OPERATING_ERROR code."""
    return get_result(code=settings.RetCode.OPERATING_ERROR, message=message)
def generate_confirmation_token(tenant_id):
serializer = URLSafeTimedSerializer(tenant_id)
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]
@ -514,3 +522,38 @@ def deep_merge(default: dict, custom: dict) -> dict:
base_dict[key] = val
return merged
def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
"""
Transform dictionary keys using a configurable mapping schema.
Args:
source_data: Original dictionary to process
key_aliases: Custom key transformation rules (Optional)
When provided, overrides default key mapping
Format: {<original_key>: <new_key>, ...}
Returns:
dict: New dictionary with transformed keys preserving original values
Example:
>>> input_data = {"old_key": "value", "another_field": 42}
>>> remap_dictionary_keys(input_data, {"old_key": "new_key"})
{'new_key': 'value', 'another_field': 42}
"""
DEFAULT_KEY_MAP = {
"chunk_num": "chunk_count",
"doc_num": "document_count",
"parser_id": "chunk_method",
"embd_id": "embedding_model",
}
transformed_data = {}
mapping = key_aliases or DEFAULT_KEY_MAP
for original_key, value in source_data.items():
mapped_key = mapping.get(original_key, original_key)
transformed_data[mapped_key] = value
return transformed_data

View File

@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import uuid
from collections import Counter
from enum import auto
from typing import Annotated, Any
from uuid import UUID
from flask import Request
from pydantic import UUID1, BaseModel, Field, StringConstraints, ValidationError, field_serializer, field_validator
from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator
from pydantic_core import PydanticCustomError
from strenum import StrEnum
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
@ -102,6 +102,71 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
return parsed_payload, None
def validate_and_parse_request_args(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None) -> tuple[dict[str, Any] | None, str | None]:
    """
    Validate request query parameters against a Pydantic model.

    Query args are extracted from the request (Flask-style
    ``request.args.to_dict``), optionally merged with *extras*, and
    validated by *validator*. Keys that came from *extras* are stripped
    from the returned payload so callers only see real query parameters.

    Args:
        request (Request): Web framework request object carrying query parameters
        validator (type[BaseModel]): Pydantic model class used for validation
        extras (dict[str, Any] | None): Optional additional values included in
            validation but removed from the final output. Defaults to None.

    Returns:
        tuple[dict[str, Any] | None, str | None]:
            - (parsed_args, None) when validation succeeds
            - (None, formatted_error_message) when validation fails

    Examples:
        >>> validate_and_parse_request_args(request, MyValidator)
        ({'param1': 'value'}, None)
        >>> validate_and_parse_request_args(request, MyValidator, extras={'internal_id': 123})
        ({'param1': 'value'}, None)  # internal_id removed from output
    """
    raw_args = request.args.to_dict(flat=True)
    try:
        if extras is not None:
            raw_args.update(extras)
        model_instance = validator(**raw_args)
    except ValidationError as exc:
        return None, format_validation_error_message(exc)

    parsed = model_instance.model_dump()
    if extras is not None:
        # Drop helper values that were only injected for validation.
        parsed = {key: value for key, value in parsed.items() if key not in extras}
    return parsed, None
def format_validation_error_message(e: ValidationError) -> str:
"""
Formats validation errors into a standardized string format.
@ -143,6 +208,105 @@ def format_validation_error_message(e: ValidationError) -> str:
return "\n".join(error_messages)
def normalize_str(v: Any) -> Any:
    """
    Return a trimmed, lowercased copy of *v* when it is a string.

    Non-string inputs pass through untouched, which makes the function
    safe in mixed-type pipelines (ints, None, lists, ... are returned
    unchanged, so falsy values like 0 and False are preserved).

    Args:
        v (Any): Input value to normalize. Accepts any Python object.

    Returns:
        Any: Normalized string if input was a string, original value otherwise.

    Example:
        >>> normalize_str("  ReadOnly  ")
        'readonly'
        >>> normalize_str(42)
        42
    """
    if not isinstance(v, str):
        return v
    return v.strip().lower()
def validate_uuid1_hex(v: Any) -> str:
    """
    Validate *v* as a version-1 UUID and return its 32-character hex form.

    Accepts either a ``UUID`` object or a UUID-formatted string
    (e.g. "2cdf0456-e9a7-11ee-8000-000000000000"); hyphens are dropped
    in the returned lowercase hex string.

    Args:
        v (Any): UUID object or UUID-formatted string to validate.

    Returns:
        str: 32-character lowercase hexadecimal string without hyphens.

    Raises:
        PydanticCustomError: code "invalid_UUID1_format" when the input is
            not a valid UUID, not a UUID-like object, or not version 1.

    NOTE(review): PydanticCustomError subclasses ValueError, so the
    version-mismatch raise inside the ``try`` is intercepted by the
    ``except`` clause and re-raised with the generic "Invalid UUID1
    format" message; downstream tests rely on that message, so this
    shadowing is preserved here — confirm it is intended.
    """
    try:
        parsed = UUID(v) if isinstance(v, str) else v
        if parsed.version != 1:
            raise PydanticCustomError("invalid_UUID1_format", "Must be a UUID1 format")
        return parsed.hex
    except (AttributeError, ValueError, TypeError):
        raise PydanticCustomError("invalid_UUID1_format", "Invalid UUID1 format")
class PermissionEnum(StrEnum):
me = auto()
team = auto()
@ -217,8 +381,8 @@ class CreateDatasetReq(Base):
avatar: str | None = Field(default=None, max_length=65535)
description: str | None = Field(default=None, max_length=65535)
embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)]
chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")]
permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
pagerank: int = Field(default=0, ge=0, le=100)
parser_config: ParserConfig | None = Field(default=None)
@ -315,22 +479,8 @@ class CreateDatasetReq(Base):
@field_validator("permission", mode="before")
@classmethod
def permission_auto_lowercase(cls, v: Any) -> Any:
"""
Normalize permission input to lowercase for consistent PermissionEnum matching.
Args:
v (Any): Raw input value for the permission field
Returns:
Lowercase string if input is string type, otherwise returns original value
Behavior:
- Converts string inputs to lowercase (e.g., "ME" "me")
- Non-string values pass through unchanged
- Works in validation pre-processing stage (before enum conversion)
"""
return v.lower() if isinstance(v, str) else v
def normalize_permission(cls, v: Any) -> Any:
return normalize_str(v)
@field_validator("parser_config", mode="before")
@classmethod
@ -387,93 +537,117 @@ class CreateDatasetReq(Base):
class UpdateDatasetReq(CreateDatasetReq):
dataset_id: UUID1 = Field(...)
dataset_id: str = Field(...)
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
@field_serializer("dataset_id")
def serialize_uuid_to_hex(self, v: uuid.UUID) -> str:
"""
Serializes a UUID version 1 object to its hexadecimal string representation.
This field serializer specifically handles UUID version 1 objects, converting them
to their canonical 32-character hexadecimal format without hyphens. The conversion
is designed for consistent serialization in API responses and database storage.
Args:
v (uuid.UUID1): The UUID version 1 object to serialize. Must be a valid
UUID1 instance generated by Python's uuid module.
Returns:
str: 32-character lowercase hexadecimal string representation
Example: "550e8400e29b41d4a716446655440000"
Raises:
AttributeError: If input is not a proper UUID object (missing hex attribute)
TypeError: If input is not a UUID1 instance (when type checking is enabled)
Notes:
- Version 1 UUIDs contain timestamp and MAC address information
- The .hex property automatically converts to lowercase hexadecimal
- For cross-version compatibility, consider typing as uuid.UUID instead
"""
return v.hex
@field_validator("dataset_id", mode="before")
@classmethod
def validate_dataset_id(cls, v: Any) -> str:
return validate_uuid1_hex(v)
class DeleteReq(Base):
ids: list[UUID1] | None = Field(...)
ids: list[str] | None = Field(...)
@field_validator("ids", mode="after")
def check_duplicate_ids(cls, v: list[UUID1] | None) -> list[str] | None:
@classmethod
def validate_ids(cls, v_list: list[str] | None) -> list[str] | None:
"""
Validates and converts a list of UUID1 objects to hexadecimal strings while checking for duplicates.
Validates and normalizes a list of UUID strings with None handling.
This validator implements a three-stage processing pipeline:
1. Null Handling - returns None for empty/null input
2. UUID Conversion - transforms UUID objects to hex strings
3. Duplicate Validation - ensures all IDs are unique
Behavior Specifications:
- Input: None Returns None (indicates no operation)
- Input: [] Returns [] (empty list for explicit no-op)
- Input: [UUID1,...] Returns validated hex strings
- Duplicates: Raises formatted PydanticCustomError
This post-processing validator performs:
1. None input handling (pass-through)
2. UUID version 1 validation for each list item
3. Duplicate value detection
4. Returns normalized UUID hex strings or None
Args:
v (list[UUID1] | None):
- None: Indicates no datasets should be processed
- Empty list: Explicit empty operation
- Populated list: Dataset UUIDs to validate/convert
v_list (list[str] | None): Input list that has passed initial validation.
Either a list of UUID strings or None.
Returns:
list[str] | None:
- None when input is None
- List of 32-character hex strings (lowercase, no hyphens)
Example: ["550e8400e29b41d4a716446655440000"]
- None if input was None
- List of normalized UUID hex strings otherwise:
* 32-character lowercase
* Valid UUID version 1
* Unique within list
Raises:
PydanticCustomError: When duplicates detected, containing:
- Error type: "duplicate_uuids"
- Template message: "Duplicate ids: '{duplicate_ids}'"
- Context: {"duplicate_ids": "id1, id2, ..."}
PydanticCustomError: With structured error details when:
- "invalid_UUID1_format": Any string fails UUIDv1 validation
- "duplicate_uuids": If duplicate IDs are detected
Example:
>>> validate([UUID("..."), UUID("...")])
["2cdf0456e9a711ee8000000000000000", ...]
Validation Rules:
- None input returns None
- Empty list returns empty list
- All non-None items must be valid UUIDv1
- No duplicates permitted
- Original order preserved
>>> validate([UUID("..."), UUID("...")]) # Duplicates
PydanticCustomError: Duplicate ids: '2cdf0456e9a711ee8000000000000000'
Examples:
Valid cases:
>>> validate_ids(None)
None
>>> validate_ids([])
[]
>>> validate_ids(["550e8400-e29b-41d4-a716-446655440000"])
["550e8400e29b41d4a716446655440000"]
Invalid cases:
>>> validate_ids(["invalid"])
# raises PydanticCustomError(invalid_UUID1_format)
>>> validate_ids(["550e...", "550e..."])
# raises PydanticCustomError(duplicate_uuids)
Security Notes:
- Validates UUID version to prevent version spoofing
- Duplicate check prevents data injection
- None handling maintains pipeline integrity
"""
if not v:
return v
if v_list is None:
return None
uuid_hex_list = [ids.hex for ids in v]
duplicates = [item for item, count in Counter(uuid_hex_list).items() if count > 1]
ids_list = []
for v in v_list:
try:
ids_list.append(validate_uuid1_hex(v))
except PydanticCustomError as e:
raise e
duplicates = [item for item, count in Counter(ids_list).items() if count > 1]
if duplicates:
duplicates_str = ", ".join(duplicates)
raise PydanticCustomError("duplicate_uuids", "Duplicate ids: '{duplicate_ids}'", {"duplicate_ids": duplicates_str})
return uuid_hex_list
return ids_list
class DeleteDatasetReq(DeleteReq): ...  # request schema for DELETE /datasets; inherits `ids` validation from DeleteReq
class OrderByEnum(StrEnum):
    # Sortable columns accepted by the list endpoints' `orderby` query arg.
    create_time = auto()
    update_time = auto()
class BaseListReq(Base):
    """Common query-argument schema for list endpoints (filtering, pagination, sorting)."""

    # 32-char hex of a UUID1 after validation; None means "no id filter".
    id: str | None = None
    name: str | None = None
    page: int = Field(default=1, ge=1)
    page_size: int = Field(default=30, ge=1)
    orderby: OrderByEnum = Field(default=OrderByEnum.create_time)
    desc: bool = Field(default=True)

    @field_validator("id", mode="before")
    @classmethod
    def validate_id(cls, v: Any) -> Any:
        # Treat an empty id (e.g. a bare "?id=" query string) as "no filter"
        # instead of rejecting it as a malformed UUID — the previous endpoint
        # behavior accepted an empty id and listed all datasets.
        if v is None or v == "":
            return None
        return validate_uuid1_hex(v)

    @field_validator("orderby", mode="before")
    @classmethod
    def normalize_orderby(cls, v: Any) -> Any:
        # Case/whitespace-insensitive matching against OrderByEnum members.
        return normalize_str(v)
class ListDatasetReq(BaseListReq): ...  # query-arg schema for GET /datasets; inherits pagination/sort fields unchanged

View File

@ -122,7 +122,7 @@ class TestDatasetCreate:
assert res["code"] == 0, res
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 102, res
assert res["code"] == 103, res
assert res["message"] == f"Dataset name '{name}' already exists", res
@pytest.mark.p3
@ -134,7 +134,7 @@ class TestDatasetCreate:
payload = {"name": name.lower()}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 102, res
assert res["code"] == 103, res
assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
@pytest.mark.p2
@ -296,14 +296,15 @@ class TestDatasetCreate:
("team", "team"),
("me_upercase", "ME"),
("team_upercase", "TEAM"),
("whitespace", " ME "),
],
ids=["me", "team", "me_upercase", "team_upercase"],
ids=["me", "team", "me_upercase", "team_upercase", "whitespace"],
)
def test_permission(self, get_http_api_auth, name, permission):
payload = {"name": name, "permission": permission}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["permission"] == permission.lower(), res
assert res["data"]["permission"] == permission.lower().strip(), res
@pytest.mark.p2
@pytest.mark.parametrize(

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import uuid
from concurrent.futures import ThreadPoolExecutor
import pytest
@ -40,8 +41,8 @@ class TestAuthorization:
)
def test_auth_invalid(self, auth, expected_code, expected_message):
res = delete_datasets(auth)
assert res["code"] == expected_code
assert res["message"] == expected_message
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res
class TestRquest:
@ -140,17 +141,25 @@ class TestDatasetsDelete:
payload = {"ids": ["not_uuid"]}
res = delete_datasets(get_http_api_auth, payload)
assert res["code"] == 101, res
assert "Input should be a valid UUID" in res["message"], res
assert "Invalid UUID1 format" in res["message"], res
res = list_datasets(get_http_api_auth)
assert len(res["data"]) == 1, res
@pytest.mark.p3
@pytest.mark.usefixtures("add_dataset_func")
def test_id_not_uuid1(self, get_http_api_auth):
payload = {"ids": [uuid.uuid4().hex]}
res = delete_datasets(get_http_api_auth, payload)
assert res["code"] == 101, res
assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p2
@pytest.mark.usefixtures("add_dataset_func")
def test_id_wrong_uuid(self, get_http_api_auth):
payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]}
res = delete_datasets(get_http_api_auth, payload)
assert res["code"] == 102, res
assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res
res = list_datasets(get_http_api_auth)
@ -170,7 +179,7 @@ class TestDatasetsDelete:
if callable(func):
payload = func(dataset_ids)
res = delete_datasets(get_http_api_auth, payload)
assert res["code"] == 102, res
assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res
res = list_datasets(get_http_api_auth)
@ -195,7 +204,7 @@ class TestDatasetsDelete:
assert res["code"] == 0, res
res = delete_datasets(get_http_api_auth, payload)
assert res["code"] == 102, res
assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p2

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import uuid
from concurrent.futures import ThreadPoolExecutor
import pytest
@ -21,8 +22,8 @@ from libs.auth import RAGFlowHttpApiAuth
from libs.utils import is_sorted
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
[
@ -34,269 +35,305 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, auth, expected_code, expected_message):
def test_auth_invalid(self, auth, expected_code, expected_message):
res = list_datasets(auth)
assert res["code"] == expected_code
assert res["message"] == expected_message
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res
@pytest.mark.usefixtures("add_datasets")
class TestDatasetsList:
@pytest.mark.p1
def test_default(self, get_http_api_auth):
res = list_datasets(get_http_api_auth, params={})
assert res["code"] == 0
assert len(res["data"]) == 5
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page": None, "page_size": 2}, 0, 2, ""),
({"page": 0, "page_size": 2}, 0, 2, ""),
({"page": 2, "page_size": 2}, 0, 2, ""),
({"page": 3, "page_size": 2}, 0, 1, ""),
({"page": "3", "page_size": 2}, 0, 1, ""),
pytest.param(
{"page": -1, "page_size": 2},
100,
0,
"1064",
marks=pytest.mark.skip(reason="issues/5851"),
),
pytest.param(
{"page": "a", "page_size": 2},
100,
0,
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_page(self, get_http_api_auth, params, expected_code, expected_page_size, expected_message):
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page_size": None}, 0, 5, ""),
({"page_size": 0}, 0, 0, ""),
({"page_size": 1}, 0, 1, ""),
({"page_size": 6}, 0, 5, ""),
({"page_size": "1"}, 0, 1, ""),
pytest.param(
{"page_size": -1},
100,
0,
"1064",
marks=pytest.mark.skip(reason="issues/5851"),
),
pytest.param(
{"page_size": "a"},
100,
0,
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_page_size(
self,
get_http_api_auth,
params,
expected_code,
expected_page_size,
expected_message,
):
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"params, expected_code, assertions, expected_message",
[
({"orderby": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", True)), ""),
pytest.param(
{"orderby": "name", "desc": "False"},
0,
lambda r: (is_sorted(r["data"]["docs"], "name", False)),
"",
marks=pytest.mark.skip(reason="issues/5851"),
),
pytest.param(
{"orderby": "unknown"},
102,
0,
"orderby should be create_time or update_time",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_orderby(
self,
get_http_api_auth,
params,
expected_code,
assertions,
expected_message,
):
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
assert assertions(res)
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"params, expected_code, assertions, expected_message",
[
({"desc": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": "true"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": "True"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": True}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": "false"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
({"desc": "False"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
({"desc": False}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", False)), ""),
pytest.param(
{"desc": "unknown"},
102,
0,
"desc should be true or false",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_desc(
self,
get_http_api_auth,
params,
expected_code,
assertions,
expected_message,
):
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
assert assertions(res)
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_num, expected_message",
[
({"name": None}, 0, 5, ""),
({"name": ""}, 0, 5, ""),
({"name": "dataset_1"}, 0, 1, ""),
({"name": "unknown"}, 102, 0, "You don't own the dataset unknown"),
],
)
def test_name(self, get_http_api_auth, params, expected_code, expected_num, expected_message):
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["name"] in [None, ""]:
assert len(res["data"]) == expected_num
else:
assert res["data"][0]["name"] == params["name"]
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_num, expected_message",
[
(None, 0, 5, ""),
("", 0, 5, ""),
(lambda r: r[0], 0, 1, ""),
("unknown", 102, 0, "You don't own the dataset unknown"),
],
)
def test_id(
self,
get_http_api_auth,
add_datasets,
dataset_id,
expected_code,
expected_num,
expected_message,
):
dataset_ids = add_datasets
if callable(dataset_id):
params = {"id": dataset_id(dataset_ids)}
else:
params = {"id": dataset_id}
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["id"] in [None, ""]:
assert len(res["data"]) == expected_num
else:
assert res["data"][0]["id"] == params["id"]
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"dataset_id, name, expected_code, expected_num, expected_message",
[
(lambda r: r[0], "dataset_0", 0, 1, ""),
(lambda r: r[0], "dataset_1", 0, 0, ""),
(lambda r: r[0], "unknown", 102, 0, "You don't own the dataset unknown"),
("id", "dataset_0", 102, 0, "You don't own the dataset id"),
],
)
def test_name_and_id(
self,
get_http_api_auth,
add_datasets,
dataset_id,
name,
expected_code,
expected_num,
expected_message,
):
dataset_ids = add_datasets
if callable(dataset_id):
params = {"id": dataset_id(dataset_ids), "name": name}
else:
params = {"id": dataset_id, "name": name}
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_num
else:
assert res["message"] == expected_message
class TestCapability:
    """Load/robustness checks for the list-datasets endpoint."""

    @pytest.mark.p3
    def test_concurrent_list(self, get_http_api_auth):
        """Fire 100 concurrent list requests; every one must succeed (code 0)."""
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(list_datasets, get_http_api_auth) for _ in range(100)]
        responses = [f.result() for f in futures]
        # FIX: the assertion was duplicated (diff residue); keep the single
        # variant that attaches the responses as diagnostic context on failure.
        assert all(r["code"] == 0 for r in responses), responses
@pytest.mark.usefixtures("add_datasets")
class TestDatasetsList:
    """List-datasets HTTP API: paging, ordering, and id/name filtering.

    The ``add_datasets`` fixture provisions five datasets (named
    ``dataset_0`` .. ``dataset_4`` — TODO confirm against the fixture) for
    every test in this class, hence the recurring expected count of 5.
    """

    @pytest.mark.p1
    def test_params_unset(self, get_http_api_auth):
        """No query parameters at all -> full listing."""
        res = list_datasets(get_http_api_auth, None)
        assert res["code"] == 0, res
        assert len(res["data"]) == 5, res

    @pytest.mark.p2
    def test_params_empty(self, get_http_api_auth):
        """Empty params dict behaves like no params."""
        res = list_datasets(get_http_api_auth, {})
        assert res["code"] == 0, res
        assert len(res["data"]) == 5, res

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "params, expected_page_size",
        [
            ({"page": 2, "page_size": 2}, 2),
            ({"page": 3, "page_size": 2}, 1),
            ({"page": 4, "page_size": 2}, 0),
            ({"page": "2", "page_size": 2}, 2),
            ({"page": 1, "page_size": 10}, 5),
        ],
        ids=["normal_middle_page", "normal_last_partial_page", "beyond_max_page", "string_page_number", "full_data_single_page"],
    )
    def test_page(self, get_http_api_auth, params, expected_page_size):
        """Paging returns the expected slice size, including the empty tail page."""
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        assert len(res["data"]) == expected_page_size, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "params, expected_code, expected_message",
        [
            ({"page": 0}, 101, "Input should be greater than or equal to 1"),
            ({"page": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"),
        ],
        ids=["page_0", "page_a"],
    )
    def test_page_invalid(self, get_http_api_auth, params, expected_code, expected_message):
        """Out-of-range / non-numeric page values are rejected by validation."""
        res = list_datasets(get_http_api_auth, params=params)
        assert res["code"] == expected_code, res
        assert expected_message in res["message"], res

    @pytest.mark.p2
    def test_page_none(self, get_http_api_auth):
        """page=None falls back to the default page."""
        params = {"page": None}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        assert len(res["data"]) == 5, res

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "params, expected_page_size",
        [
            ({"page_size": 1}, 1),
            ({"page_size": 3}, 3),
            ({"page_size": 5}, 5),
            ({"page_size": 6}, 5),
            ({"page_size": "1"}, 1),
        ],
        ids=["min_valid_page_size", "medium_page_size", "page_size_equals_total", "page_size_exceeds_total", "string_type_page_size"],
    )
    def test_page_size(self, get_http_api_auth, params, expected_page_size):
        """page_size caps the result length (never more than the total)."""
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        assert len(res["data"]) == expected_page_size, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "params, expected_code, expected_message",
        [
            ({"page_size": 0}, 101, "Input should be greater than or equal to 1"),
            ({"page_size": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"),
        ],
    )
    def test_page_size_invalid(self, get_http_api_auth, params, expected_code, expected_message):
        """Out-of-range / non-numeric page_size values are rejected."""
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == expected_code, res
        assert expected_message in res["message"], res

    @pytest.mark.p2
    def test_page_size_none(self, get_http_api_auth):
        """page_size=None falls back to the default page size."""
        params = {"page_size": None}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        assert len(res["data"]) == 5, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "params, assertions",
        [
            ({"orderby": "create_time"}, lambda r: (is_sorted(r["data"], "create_time", True))),
            ({"orderby": "update_time"}, lambda r: (is_sorted(r["data"], "update_time", True))),
            ({"orderby": "CREATE_TIME"}, lambda r: (is_sorted(r["data"], "create_time", True))),
            ({"orderby": "UPDATE_TIME"}, lambda r: (is_sorted(r["data"], "update_time", True))),
            # FIX: " create_time " normalizes to create_time, so the result must be
            # checked against create_time ordering (was update_time by copy-paste).
            ({"orderby": " create_time "}, lambda r: (is_sorted(r["data"], "create_time", True))),
        ],
        ids=["orderby_create_time", "orderby_update_time", "orderby_create_time_upper", "orderby_update_time_upper", "whitespace"],
    )
    def test_orderby(self, get_http_api_auth, params, assertions):
        """orderby is case-insensitive and whitespace-tolerant."""
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        if callable(assertions):
            assert assertions(res), res

    @pytest.mark.p3
    def test_invalid_params(self, get_http_api_auth):
        """Unknown query fields are rejected by the Pydantic model (extra=forbid)."""
        # FIX: the response was previously never asserted; mirror the checks of
        # test_field_unsupported for an unknown query key.
        params = {"a": "b"}
        res = list_datasets(get_http_api_auth, params=params)
        assert res["code"] == 101, res
        assert "Extra inputs are not permitted" in res["message"], res

    # FIX: priority marker was missing on this test (every sibling has one).
    @pytest.mark.p3
    @pytest.mark.parametrize(
        "params",
        [
            {"orderby": ""},
            {"orderby": "unknown"},
        ],
        ids=["empty", "unknown"],
    )
    def test_orderby_invalid(self, get_http_api_auth, params):
        """orderby only accepts 'create_time' or 'update_time'."""
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 101, res
        assert "Input should be 'create_time' or 'update_time'" in res["message"], res

    @pytest.mark.p3
    def test_orderby_none(self, get_http_api_auth):
        """orderby=None falls back to the default (create_time, descending)."""
        # FIX: the key was misspelled "order_by"; the intent is orderby=None.
        params = {"orderby": None}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        assert is_sorted(res["data"], "create_time", True), res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "params, assertions",
        [
            ({"desc": True}, lambda r: (is_sorted(r["data"], "create_time", True))),
            ({"desc": False}, lambda r: (is_sorted(r["data"], "create_time", False))),
            ({"desc": "true"}, lambda r: (is_sorted(r["data"], "create_time", True))),
            ({"desc": "false"}, lambda r: (is_sorted(r["data"], "create_time", False))),
            ({"desc": 1}, lambda r: (is_sorted(r["data"], "create_time", True))),
            ({"desc": 0}, lambda r: (is_sorted(r["data"], "create_time", False))),
            ({"desc": "yes"}, lambda r: (is_sorted(r["data"], "create_time", True))),
            ({"desc": "no"}, lambda r: (is_sorted(r["data"], "create_time", False))),
            ({"desc": "y"}, lambda r: (is_sorted(r["data"], "create_time", True))),
            ({"desc": "n"}, lambda r: (is_sorted(r["data"], "create_time", False))),
        ],
        ids=["desc=True", "desc=False", "desc=true", "desc=false", "desc=1", "desc=0", "desc=yes", "desc=no", "desc=y", "desc=n"],
    )
    def test_desc(self, get_http_api_auth, params, assertions):
        """desc accepts the usual boolean spellings and flips the sort order."""
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        if callable(assertions):
            assert assertions(res), res

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "params",
        [
            {"desc": 3.14},
            {"desc": "unknown"},
        ],
        # FIX: the first id said "empty" (copied from orderby) but the value is a float.
        ids=["float", "unknown"],
    )
    def test_desc_invalid(self, get_http_api_auth, params):
        """Non-boolean desc values are rejected by validation."""
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 101, res
        assert "Input should be a valid boolean, unable to interpret input" in res["message"], res

    @pytest.mark.p3
    def test_desc_none(self, get_http_api_auth):
        """desc=None falls back to the default (descending)."""
        params = {"desc": None}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        assert is_sorted(res["data"], "create_time", True), res

    @pytest.mark.p1
    def test_name(self, get_http_api_auth):
        """Exact name filter returns the single matching dataset."""
        params = {"name": "dataset_1"}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        assert len(res["data"]) == 1, res
        assert res["data"][0]["name"] == "dataset_1", res

    @pytest.mark.p2
    def test_name_wrong(self, get_http_api_auth):
        """A name the tenant does not own yields a permission error (108)."""
        params = {"name": "wrong name"}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 108, res
        assert "lacks permission for dataset" in res["message"], res

    @pytest.mark.p2
    def test_name_empty(self, get_http_api_auth):
        """An empty name is treated as no filter."""
        params = {"name": ""}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        assert len(res["data"]) == 5, res

    @pytest.mark.p2
    def test_name_none(self, get_http_api_auth):
        """name=None is treated as no filter."""
        params = {"name": None}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        assert len(res["data"]) == 5, res

    @pytest.mark.p1
    def test_id(self, get_http_api_auth, add_datasets):
        """Filtering by a known id returns exactly that dataset."""
        dataset_ids = add_datasets
        params = {"id": dataset_ids[0]}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        # FIX: a contradictory leftover assertion (len == 5) was removed;
        # filtering by one id must return exactly one dataset.
        assert len(res["data"]) == 1, res
        assert res["data"][0]["id"] == dataset_ids[0], res

    @pytest.mark.p2
    def test_id_not_uuid(self, get_http_api_auth):
        """A non-UUID id string fails validation."""
        params = {"id": "not_uuid"}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 101, res
        assert "Invalid UUID1 format" in res["message"], res

    @pytest.mark.p2
    def test_id_not_uuid1(self, get_http_api_auth):
        """A valid UUID4 is still rejected: only UUID1 ids are accepted."""
        params = {"id": uuid.uuid4().hex}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 101, res
        assert "Invalid UUID1 format" in res["message"], res

    @pytest.mark.p2
    def test_id_wrong_uuid(self, get_http_api_auth):
        """A well-formed UUID1 the tenant does not own yields a permission error."""
        params = {"id": "d94a8dc02c9711f0930f7fbc369eab6d"}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 108, res
        assert "lacks permission for dataset" in res["message"], res

    @pytest.mark.p2
    def test_id_empty(self, get_http_api_auth):
        """An empty id string fails UUID validation (unlike an empty name)."""
        params = {"id": ""}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 101, res
        assert "Invalid UUID1 format" in res["message"], res

    @pytest.mark.p2
    def test_id_none(self, get_http_api_auth):
        """id=None is treated as no filter."""
        params = {"id": None}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        assert len(res["data"]) == 5, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "func, name, expected_num",
        [
            (lambda r: r[0], "dataset_0", 1),
            (lambda r: r[0], "dataset_1", 0),
        ],
        ids=["name_and_id_match", "name_and_id_mismatch"],
    )
    def test_name_and_id(self, get_http_api_auth, add_datasets, func, name, expected_num):
        """id+name must refer to the same dataset to produce a hit."""
        dataset_ids = add_datasets
        if callable(func):
            params = {"id": func(dataset_ids), "name": name}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 0, res
        assert len(res["data"]) == expected_num, res

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "dataset_id, name",
        [
            (lambda r: r[0], "wrong_name"),
            (uuid.uuid1().hex, "dataset_0"),
        ],
        ids=["name", "id"],
    )
    def test_name_and_id_wrong(self, get_http_api_auth, add_datasets, dataset_id, name):
        """If either id or name is unowned, the combination is a permission error."""
        dataset_ids = add_datasets
        if callable(dataset_id):
            params = {"id": dataset_id(dataset_ids), "name": name}
        else:
            params = {"id": dataset_id, "name": name}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 108, res
        assert "lacks permission for dataset" in res["message"], res

    @pytest.mark.p2
    def test_field_unsupported(self, get_http_api_auth):
        """Unknown query fields are rejected by the Pydantic model."""
        params = {"unknown_field": "unknown_field"}
        res = list_datasets(get_http_api_auth, params)
        assert res["code"] == 101, res
        assert "Extra inputs are not permitted" in res["message"], res

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import uuid
from concurrent.futures import ThreadPoolExecutor
import pytest
@ -98,16 +99,23 @@ class TestCapability:
class TestDatasetUpdate:
@pytest.mark.p3
def test_dataset_id_not_uuid(self, get_http_api_auth):
    """Updating with a non-UUID dataset id must fail validation (101)."""
    # FIX: removed interleaved pre-change diff lines (old payload name and the
    # old "Input should be a valid UUID" assertion); keep the post-change
    # behavior: the API reports "Invalid UUID1 format".
    payload = {"name": "not uuid"}
    res = update_dataset(get_http_api_auth, "not_uuid", payload)
    assert res["code"] == 101, res
    assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p3
def test_dataset_id_not_uuid1(self, get_http_api_auth):
    """A syntactically valid UUID4 is still rejected: only UUID1 ids pass."""
    payload = {"name": "not uuid1"}
    response = update_dataset(get_http_api_auth, uuid.uuid4().hex, payload)
    assert response["code"] == 101, response
    assert "Invalid UUID1 format" in response["message"], response
@pytest.mark.p3
def test_dataset_id_wrong_uuid(self, get_http_api_auth):
    """A well-formed UUID1 the tenant does not own is a permission error."""
    # FIX: removed interleaved pre-change diff lines (old payload name and the
    # old code-102 assertion); keep the post-change behavior: code 108 with a
    # "lacks permission for dataset" message.
    payload = {"name": "wrong uuid"}
    res = update_dataset(get_http_api_auth, "d94a8dc02c9711f0930f7fbc369eab6d", payload)
    assert res["code"] == 108, res
    assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p1
@ -322,8 +330,9 @@ class TestDatasetUpdate:
"team",
"ME",
"TEAM",
" ME ",
],
ids=["me", "team", "me_upercase", "team_upercase"],
ids=["me", "team", "me_upercase", "team_upercase", "whitespace"],
)
def test_permission(self, get_http_api_auth, add_dataset_func, permission):
dataset_id = add_dataset_func
@ -333,7 +342,7 @@ class TestDatasetUpdate:
res = list_datasets(get_http_api_auth)
assert res["code"] == 0, res
assert res["data"][0]["permission"] == permission.lower(), res
assert res["data"][0]["permission"] == permission.lower().strip(), res
@pytest.mark.p2
@pytest.mark.parametrize(
@ -734,7 +743,6 @@ class TestDatasetUpdate:
assert res["code"] == 0, res
res = list_datasets(get_http_api_auth)
print(res)
assert res["code"] == 0, res
assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res
@ -757,7 +765,6 @@ class TestDatasetUpdate:
assert res["code"] == 0, res
res = list_datasets(get_http_api_auth, {"id": dataset_id})
print(res)
assert res["code"] == 0, res
assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res