diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index c96a8975d..f76cf2f9d 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -32,12 +32,22 @@ from api.utils.api_utils import ( deep_merge, get_error_argument_result, get_error_data_result, + get_error_operating_result, + get_error_permission_result, get_parser_config, get_result, + remap_dictionary_keys, token_required, verify_embedding_availability, ) -from api.utils.validation_utils import CreateDatasetReq, DeleteDatasetReq, UpdateDatasetReq, validate_and_parse_json_request +from api.utils.validation_utils import ( + CreateDatasetReq, + DeleteDatasetReq, + ListDatasetReq, + UpdateDatasetReq, + validate_and_parse_json_request, + validate_and_parse_request_args, +) @manager.route("/datasets", methods=["POST"]) # noqa: F821 @@ -113,7 +123,7 @@ def create(tenant_id): try: if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): - return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") + return get_error_operating_result(message=f"Dataset name '{req['name']}' already exists") except OperationalError as e: logging.exception(e) return get_error_data_result(message="Database operation failed") @@ -126,7 +136,7 @@ def create(tenant_id): try: ok, t = TenantService.get_by_id(tenant_id) if not ok: - return get_error_data_result(message="Tenant not found") + return get_error_permission_result(message="Tenant not found") except OperationalError as e: logging.exception(e) return get_error_data_result(message="Database operation failed") @@ -153,16 +163,7 @@ def create(tenant_id): logging.exception(e) return get_error_data_result(message="Database operation failed") - response_data = {} - key_mapping = { - "chunk_num": "chunk_count", - "doc_num": "document_count", - "parser_id": "chunk_method", - "embd_id": "embedding_model", - } - for key, value in k.to_dict().items(): - new_key = key_mapping.get(key, key) - 
response_data[new_key] = value + response_data = remap_dictionary_keys(k.to_dict()) return get_result(data=response_data) @@ -232,7 +233,7 @@ def delete(tenant_id): logging.exception(e) return get_error_data_result(message="Database operation failed") if len(error_kb_ids) > 0: - return get_error_data_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") + return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") errors = [] success_count = 0 @@ -347,7 +348,7 @@ def update(tenant_id, dataset_id): try: kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) if kb is None: - return get_error_data_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") except OperationalError as e: logging.exception(e) return get_error_data_result(message="Database operation failed") @@ -418,7 +419,7 @@ def list_datasets(tenant_id): name: page_size type: integer required: false - default: 1024 + default: 30 description: Number of items per page. 
- in: query name: orderby @@ -445,47 +446,46 @@ def list_datasets(tenant_id): items: type: object """ - id = request.args.get("id") - name = request.args.get("name") - if id: - kbs = KnowledgebaseService.get_kb_by_id(id, tenant_id) + args, err = validate_and_parse_request_args(request, ListDatasetReq) + if err is not None: + return get_error_argument_result(err) + + kb_id = request.args.get("id") + name = args.get("name") + if kb_id: + try: + kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id) + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") if not kbs: - return get_error_data_result(f"You don't own the dataset {id}") + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'") if name: - kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id) + try: + kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id) + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") if not kbs: - return get_error_data_result(f"You don't own the dataset {name}") - page_number = int(request.args.get("page", 1)) - items_per_page = int(request.args.get("page_size", 30)) - orderby = request.args.get("orderby", "create_time") - if request.args.get("desc", "false").lower() not in ["true", "false"]: - return get_error_data_result("desc should be true or false") - if request.args.get("desc", "true").lower() == "false": - desc = False - else: - desc = True - tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) - kbs = KnowledgebaseService.get_list( - [m["tenant_id"] for m in tenants], - tenant_id, - page_number, - items_per_page, - orderby, - desc, - id, - name, - ) - renamed_list = [] + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'") + + try: + tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) + kbs = 
KnowledgebaseService.get_list( + [m["tenant_id"] for m in tenants], + tenant_id, + args["page"], + args["page_size"], + args["orderby"], + args["desc"], + kb_id, + name, + ) + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + + response_data_list = [] for kb in kbs: - key_mapping = { - "chunk_num": "chunk_count", - "doc_num": "document_count", - "parser_id": "chunk_method", - "embd_id": "embedding_model", - } - renamed_data = {} - for key, value in kb.items(): - new_key = key_mapping.get(key, key) - renamed_data[new_key] = value - renamed_list.append(renamed_data) - return get_result(data=renamed_list) + response_data_list.append(remap_dictionary_keys(kb)) + return get_result(data=response_data_list) diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 70e1282bd..c0f2c1957 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -329,6 +329,14 @@ def get_error_argument_result(message="Invalid arguments"): return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=message) +def get_error_permission_result(message="Permission error"): + return get_result(code=settings.RetCode.PERMISSION_ERROR, message=message) + + +def get_error_operating_result(message="Operating error"): + return get_result(code=settings.RetCode.OPERATING_ERROR, message=message) + + def generate_confirmation_token(tenant_id): serializer = URLSafeTimedSerializer(tenant_id) return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34] @@ -514,3 +522,38 @@ def deep_merge(default: dict, custom: dict) -> dict: base_dict[key] = val return merged + + +def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict: + """ + Transform dictionary keys using a configurable mapping schema. 
+ + Args: + source_data: Original dictionary to process + key_aliases: Custom key transformation rules (Optional) + When provided, overrides default key mapping + Format: {original_key: new_key, ...} + + Returns: + dict: New dictionary with transformed keys preserving original values + + Example: + >>> input_data = {"old_key": "value", "another_field": 42} + >>> remap_dictionary_keys(input_data, {"old_key": "new_key"}) + {'new_key': 'value', 'another_field': 42} + """ + DEFAULT_KEY_MAP = { + "chunk_num": "chunk_count", + "doc_num": "document_count", + "parser_id": "chunk_method", + "embd_id": "embedding_model", + } + + transformed_data = {} + mapping = key_aliases or DEFAULT_KEY_MAP + + for original_key, value in source_data.items(): + mapped_key = mapping.get(original_key, original_key) + transformed_data[mapped_key] = value + + return transformed_data diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index eb5c44b9a..206a91f12 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import uuid from collections import Counter from enum import auto from typing import Annotated, Any +from uuid import UUID from flask import Request -from pydantic import UUID1, BaseModel, Field, StringConstraints, ValidationError, field_serializer, field_validator +from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator from pydantic_core import PydanticCustomError from strenum import StrEnum from werkzeug.exceptions import BadRequest, UnsupportedMediaType @@ -102,6 +102,71 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel] return parsed_payload, None +def validate_and_parse_request_args(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None) -> tuple[dict[str, Any] | None, str | None]: + """ + Validates and parses request arguments against a Pydantic model. + + This function performs a complete request validation workflow: + 1. Extracts query parameters from the request + 2. Merges with optional extra values (if provided) + 3. Validates against the specified Pydantic model + 4. Cleans the output by removing extra values + 5. Returns either parsed data or an error message + + Args: + request (Request): Web framework request object containing query parameters + validator (type[BaseModel]): Pydantic model class for validation + extras (dict[str, Any] | None): Optional additional values to include in validation + but exclude from final output. Defaults to None. 
+ + Returns: + tuple[dict[str, Any] | None, str | None]: + - First element: Validated/parsed arguments as dict if successful, None otherwise + - Second element: Formatted error message if validation failed, None otherwise + + Behavior: + - Query parameters are merged with extras before validation + - Extras are automatically removed from the final output + - All validation errors are formatted into a human-readable string + + Raises: + TypeError: If validator is not a Pydantic BaseModel subclass + + Examples: + Successful validation: + >>> validate_and_parse_request_args(request, MyValidator) + ({'param1': 'value'}, None) + + Failed validation: + >>> validate_and_parse_request_args(request, MyValidator) + (None, "param1: Field required") + + With extras: + >>> validate_and_parse_request_args(request, MyValidator, extras={'internal_id': 123}) + ({'param1': 'value'}, None) # internal_id removed from output + + Notes: + - Uses request.args.to_dict() for Flask-compatible parameter extraction + - Maintains immutability of original request arguments + - Preserves type conversion from Pydantic validation + """ + args = request.args.to_dict(flat=True) + try: + if extras is not None: + args.update(extras) + validated_args = validator(**args) + except ValidationError as e: + return None, format_validation_error_message(e) + + parsed_args = validated_args.model_dump() + if extras is not None: + for key in list(parsed_args.keys()): + if key in extras: + del parsed_args[key] + + return parsed_args, None + + def format_validation_error_message(e: ValidationError) -> str: """ Formats validation errors into a standardized string format. @@ -143,6 +208,105 @@ def format_validation_error_message(e: ValidationError) -> str: return "\n".join(error_messages) +def normalize_str(v: Any) -> Any: + """ + Normalizes string values to a standard format while preserving non-string inputs. + + Performs the following transformations when input is a string: + 1. 
Trims leading/trailing whitespace (str.strip()) + 2. Converts to lowercase (str.lower()) + + Non-string inputs are returned unchanged, making this function safe for mixed-type + processing pipelines. + + Args: + v (Any): Input value to normalize. Accepts any Python object. + + Returns: + Any: Normalized string if input was string-type, original value otherwise. + + Behavior Examples: + String Input: " Admin " → "admin" + Empty String: " " → "" (empty string) + Non-String: + - 123 → 123 + - None → None + - ["User"] → ["User"] + + Typical Use Cases: + - Standardizing user input + - Preparing data for case-insensitive comparison + - Cleaning API parameters + - Normalizing configuration values + + Edge Cases: + - Unicode whitespace is handled by str.strip() + - Locale-independent lowercasing (str.lower()) + - Preserves falsy values (0, False, etc.) + + Example: + >>> normalize_str(" ReadOnly ") + 'readonly' + >>> normalize_str(42) + 42 + """ + if isinstance(v, str): + stripped = v.strip() + normalized = stripped.lower() + return normalized + return v + + +def validate_uuid1_hex(v: Any) -> str: + """ + Validates and converts input to a UUID version 1 hexadecimal string. + + This function performs strict validation and normalization: + 1. Accepts either UUID objects or UUID-formatted strings + 2. Verifies the UUID is version 1 (time-based) + 3. Returns the 32-character hexadecimal representation + + Args: + v (Any): Input value to validate. Can be: + - UUID object (must be version 1) + - String in UUID format (e.g. 
"550e8400-e29b-41d4-a716-446655440000") + + Returns: + str: 32-character lowercase hexadecimal string without hyphens + Example: "550e8400e29b41d4a716446655440000" + + Raises: + PydanticCustomError: With code "invalid_UUID1_format" when: + - Input is not a UUID object or valid UUID string + - UUID version is not 1 + - String doesn't match UUID format + + Examples: + Valid cases: + >>> validate_uuid1_hex("550e8400-e29b-41d4-a716-446655440000") + '550e8400e29b41d4a716446655440000' + >>> validate_uuid1_hex(UUID('550e8400-e29b-41d4-a716-446655440000')) + '550e8400e29b41d4a716446655440000' + + Invalid cases: + >>> validate_uuid1_hex("not-a-uuid") # raises PydanticCustomError + >>> validate_uuid1_hex(12345) # raises PydanticCustomError + >>> validate_uuid1_hex(UUID(int=0)) # v4, raises PydanticCustomError + + Notes: + - Uses Python's built-in UUID parser for format validation + - Version check prevents accidental use of other UUID versions + - Hyphens in input strings are automatically removed in output + """ + try: + uuid_obj = UUID(v) if isinstance(v, str) else v + if uuid_obj.version != 1: + raise PydanticCustomError("invalid_UUID1_format", "Must be a UUID1 format") + return uuid_obj.hex + except (AttributeError, ValueError, TypeError): + raise PydanticCustomError("invalid_UUID1_format", "Invalid UUID1 format") + + class PermissionEnum(StrEnum): me = auto() team = auto() @@ -217,8 +381,8 @@ class CreateDatasetReq(Base): avatar: str | None = Field(default=None, max_length=65535) description: str | None = Field(default=None, max_length=65535) embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")] - permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)] - chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), 
Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")] + permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16) + chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id") pagerank: int = Field(default=0, ge=0, le=100) parser_config: ParserConfig | None = Field(default=None) @@ -315,22 +479,8 @@ class CreateDatasetReq(Base): @field_validator("permission", mode="before") @classmethod - def permission_auto_lowercase(cls, v: Any) -> Any: - """ - Normalize permission input to lowercase for consistent PermissionEnum matching. - - Args: - v (Any): Raw input value for the permission field - - Returns: - Lowercase string if input is string type, otherwise returns original value - - Behavior: - - Converts string inputs to lowercase (e.g., "ME" → "me") - - Non-string values pass through unchanged - - Works in validation pre-processing stage (before enum conversion) - """ - return v.lower() if isinstance(v, str) else v + def normalize_permission(cls, v: Any) -> Any: + return normalize_str(v) @field_validator("parser_config", mode="before") @classmethod @@ -387,93 +537,117 @@ class CreateDatasetReq(Base): class UpdateDatasetReq(CreateDatasetReq): - dataset_id: UUID1 = Field(...) + dataset_id: str = Field(...) name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")] - @field_serializer("dataset_id") - def serialize_uuid_to_hex(self, v: uuid.UUID) -> str: - """ - Serializes a UUID version 1 object to its hexadecimal string representation. - - This field serializer specifically handles UUID version 1 objects, converting them - to their canonical 32-character hexadecimal format without hyphens. The conversion - is designed for consistent serialization in API responses and database storage. - - Args: - v (uuid.UUID1): The UUID version 1 object to serialize. 
Must be a valid - UUID1 instance generated by Python's uuid module. - - Returns: - str: 32-character lowercase hexadecimal string representation - Example: "550e8400e29b41d4a716446655440000" - - Raises: - AttributeError: If input is not a proper UUID object (missing hex attribute) - TypeError: If input is not a UUID1 instance (when type checking is enabled) - - Notes: - - Version 1 UUIDs contain timestamp and MAC address information - - The .hex property automatically converts to lowercase hexadecimal - - For cross-version compatibility, consider typing as uuid.UUID instead - """ - return v.hex + @field_validator("dataset_id", mode="before") + @classmethod + def validate_dataset_id(cls, v: Any) -> str: + return validate_uuid1_hex(v) class DeleteReq(Base): - ids: list[UUID1] | None = Field(...) + ids: list[str] | None = Field(...) @field_validator("ids", mode="after") - def check_duplicate_ids(cls, v: list[UUID1] | None) -> list[str] | None: + @classmethod + def validate_ids(cls, v_list: list[str] | None) -> list[str] | None: """ - Validates and converts a list of UUID1 objects to hexadecimal strings while checking for duplicates. + Validates and normalizes a list of UUID strings with None handling. - This validator implements a three-stage processing pipeline: - 1. Null Handling - returns None for empty/null input - 2. UUID Conversion - transforms UUID objects to hex strings - 3. Duplicate Validation - ensures all IDs are unique - - Behavior Specifications: - - Input: None → Returns None (indicates no operation) - - Input: [] → Returns [] (empty list for explicit no-op) - - Input: [UUID1,...] → Returns validated hex strings - - Duplicates: Raises formatted PydanticCustomError + This post-processing validator performs: + 1. None input handling (pass-through) + 2. UUID version 1 validation for each list item + 3. Duplicate value detection + 4. 
Returns normalized UUID hex strings or None Args: - v (list[UUID1] | None): - - None: Indicates no datasets should be processed - - Empty list: Explicit empty operation - - Populated list: Dataset UUIDs to validate/convert + v_list (list[str] | None): Input list that has passed initial validation. + Either a list of UUID strings or None. Returns: list[str] | None: - - None when input is None - - List of 32-character hex strings (lowercase, no hyphens) - Example: ["550e8400e29b41d4a716446655440000"] + - None if input was None + - List of normalized UUID hex strings otherwise: + * 32-character lowercase + * Valid UUID version 1 + * Unique within list Raises: - PydanticCustomError: When duplicates detected, containing: - - Error type: "duplicate_uuids" - - Template message: "Duplicate ids: '{duplicate_ids}'" - - Context: {"duplicate_ids": "id1, id2, ..."} + PydanticCustomError: With structured error details when: + - "invalid_UUID1_format": Any string fails UUIDv1 validation + - "duplicate_uuids": If duplicate IDs are detected - Example: - >>> validate([UUID("..."), UUID("...")]) - ["2cdf0456e9a711ee8000000000000000", ...] 
+ Validation Rules: + - None input returns None + - Empty list returns empty list + - All non-None items must be valid UUIDv1 + - No duplicates permitted + - Original order preserved - >>> validate([UUID("..."), UUID("...")]) # Duplicates - PydanticCustomError: Duplicate ids: '2cdf0456e9a711ee8000000000000000' + Examples: + Valid cases: + >>> validate_ids(None) + None + >>> validate_ids([]) + [] + >>> validate_ids(["550e8400-e29b-41d4-a716-446655440000"]) + ["550e8400e29b41d4a716446655440000"] + + Invalid cases: + >>> validate_ids(["invalid"]) + # raises PydanticCustomError(invalid_UUID1_format) + >>> validate_ids(["550e...", "550e..."]) + # raises PydanticCustomError(duplicate_uuids) + + Security Notes: + - Validates UUID version to prevent version spoofing + - Duplicate check prevents data injection + - None handling maintains pipeline integrity """ - if not v: - return v + if v_list is None: + return None - uuid_hex_list = [ids.hex for ids in v] - duplicates = [item for item, count in Counter(uuid_hex_list).items() if count > 1] + ids_list = [] + for v in v_list: + try: + ids_list.append(validate_uuid1_hex(v)) + except PydanticCustomError as e: + raise e + duplicates = [item for item, count in Counter(ids_list).items() if count > 1] if duplicates: duplicates_str = ", ".join(duplicates) raise PydanticCustomError("duplicate_uuids", "Duplicate ids: '{duplicate_ids}'", {"duplicate_ids": duplicates_str}) - return uuid_hex_list + return ids_list class DeleteDatasetReq(DeleteReq): ... 
+ + +class OrderByEnum(StrEnum): + create_time = auto() + update_time = auto() + + +class BaseListReq(Base): + id: str | None = None + name: str | None = None + page: int = Field(default=1, ge=1) + page_size: int = Field(default=30, ge=1) + orderby: OrderByEnum = Field(default=OrderByEnum.create_time) + desc: bool = Field(default=True) + + @field_validator("id", mode="before") + @classmethod + def validate_id(cls, v: Any) -> str: + return validate_uuid1_hex(v) + + @field_validator("orderby", mode="before") + @classmethod + def normalize_orderby(cls, v: Any) -> Any: + return normalize_str(v) + + +class ListDatasetReq(BaseListReq): ... diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py b/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py index e72fbed64..fcea6fdf5 100644 --- a/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py +++ b/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py @@ -122,7 +122,7 @@ class TestDatasetCreate: assert res["code"] == 0, res res = create_dataset(get_http_api_auth, payload) - assert res["code"] == 102, res + assert res["code"] == 103, res assert res["message"] == f"Dataset name '{name}' already exists", res @pytest.mark.p3 @@ -134,7 +134,7 @@ class TestDatasetCreate: payload = {"name": name.lower()} res = create_dataset(get_http_api_auth, payload) - assert res["code"] == 102, res + assert res["code"] == 103, res assert res["message"] == f"Dataset name '{name.lower()}' already exists", res @pytest.mark.p2 @@ -296,14 +296,15 @@ class TestDatasetCreate: ("team", "team"), ("me_upercase", "ME"), ("team_upercase", "TEAM"), + ("whitespace", " ME "), ], - ids=["me", "team", "me_upercase", "team_upercase"], + ids=["me", "team", "me_upercase", "team_upercase", "whitespace"], ) def test_permission(self, get_http_api_auth, name, permission): payload = {"name": name, "permission": permission} res = create_dataset(get_http_api_auth, 
payload) assert res["code"] == 0, res - assert res["data"]["permission"] == permission.lower(), res + assert res["data"]["permission"] == permission.lower().strip(), res @pytest.mark.p2 @pytest.mark.parametrize( diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/test_delete_datasets.py b/sdk/python/test/test_http_api/test_dataset_mangement/test_delete_datasets.py index 78b4efb58..a73a1568b 100644 --- a/sdk/python/test/test_http_api/test_dataset_mangement/test_delete_datasets.py +++ b/sdk/python/test/test_http_api/test_dataset_mangement/test_delete_datasets.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import uuid from concurrent.futures import ThreadPoolExecutor import pytest @@ -40,8 +41,8 @@ class TestAuthorization: ) def test_auth_invalid(self, auth, expected_code, expected_message): res = delete_datasets(auth) - assert res["code"] == expected_code - assert res["message"] == expected_message + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res class TestRquest: @@ -140,17 +141,25 @@ class TestDatasetsDelete: payload = {"ids": ["not_uuid"]} res = delete_datasets(get_http_api_auth, payload) assert res["code"] == 101, res - assert "Input should be a valid UUID" in res["message"], res + assert "Invalid UUID1 format" in res["message"], res res = list_datasets(get_http_api_auth) assert len(res["data"]) == 1, res + @pytest.mark.p3 + @pytest.mark.usefixtures("add_dataset_func") + def test_id_not_uuid1(self, get_http_api_auth): + payload = {"ids": [uuid.uuid4().hex]} + res = delete_datasets(get_http_api_auth, payload) + assert res["code"] == 101, res + assert "Invalid UUID1 format" in res["message"], res + @pytest.mark.p2 @pytest.mark.usefixtures("add_dataset_func") def test_id_wrong_uuid(self, get_http_api_auth): payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]} res = delete_datasets(get_http_api_auth, payload) - assert res["code"] 
== 102, res + assert res["code"] == 108, res assert "lacks permission for dataset" in res["message"], res res = list_datasets(get_http_api_auth) @@ -170,7 +179,7 @@ class TestDatasetsDelete: if callable(func): payload = func(dataset_ids) res = delete_datasets(get_http_api_auth, payload) - assert res["code"] == 102, res + assert res["code"] == 108, res assert "lacks permission for dataset" in res["message"], res res = list_datasets(get_http_api_auth) @@ -195,7 +204,7 @@ class TestDatasetsDelete: assert res["code"] == 0, res res = delete_datasets(get_http_api_auth, payload) - assert res["code"] == 102, res + assert res["code"] == 108, res assert "lacks permission for dataset" in res["message"], res @pytest.mark.p2 diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/test_list_datasets.py b/sdk/python/test/test_http_api/test_dataset_mangement/test_list_datasets.py index 6eaaef57a..d81584aa5 100644 --- a/sdk/python/test/test_http_api/test_dataset_mangement/test_list_datasets.py +++ b/sdk/python/test/test_http_api/test_dataset_mangement/test_list_datasets.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import uuid from concurrent.futures import ThreadPoolExecutor import pytest @@ -21,8 +22,8 @@ from libs.auth import RAGFlowHttpApiAuth from libs.utils import is_sorted -@pytest.mark.p1 class TestAuthorization: + @pytest.mark.p1 @pytest.mark.parametrize( "auth, expected_code, expected_message", [ @@ -34,269 +35,305 @@ class TestAuthorization: ), ], ) - def test_invalid_auth(self, auth, expected_code, expected_message): + def test_auth_invalid(self, auth, expected_code, expected_message): res = list_datasets(auth) - assert res["code"] == expected_code - assert res["message"] == expected_message + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res -@pytest.mark.usefixtures("add_datasets") -class TestDatasetsList: - @pytest.mark.p1 - def test_default(self, get_http_api_auth): - res = list_datasets(get_http_api_auth, params={}) - - assert res["code"] == 0 - assert len(res["data"]) == 5 - - @pytest.mark.p1 - @pytest.mark.parametrize( - "params, expected_code, expected_page_size, expected_message", - [ - ({"page": None, "page_size": 2}, 0, 2, ""), - ({"page": 0, "page_size": 2}, 0, 2, ""), - ({"page": 2, "page_size": 2}, 0, 2, ""), - ({"page": 3, "page_size": 2}, 0, 1, ""), - ({"page": "3", "page_size": 2}, 0, 1, ""), - pytest.param( - {"page": -1, "page_size": 2}, - 100, - 0, - "1064", - marks=pytest.mark.skip(reason="issues/5851"), - ), - pytest.param( - {"page": "a", "page_size": 2}, - 100, - 0, - """ValueError("invalid literal for int() with base 10: \'a\'")""", - marks=pytest.mark.skip(reason="issues/5851"), - ), - ], - ) - def test_page(self, get_http_api_auth, params, expected_code, expected_page_size, expected_message): - res = list_datasets(get_http_api_auth, params=params) - assert res["code"] == expected_code - if expected_code == 0: - assert len(res["data"]) == expected_page_size - else: - assert res["message"] == expected_message - - @pytest.mark.p1 - @pytest.mark.parametrize( - "params, expected_code, 
expected_page_size, expected_message", - [ - ({"page_size": None}, 0, 5, ""), - ({"page_size": 0}, 0, 0, ""), - ({"page_size": 1}, 0, 1, ""), - ({"page_size": 6}, 0, 5, ""), - ({"page_size": "1"}, 0, 1, ""), - pytest.param( - {"page_size": -1}, - 100, - 0, - "1064", - marks=pytest.mark.skip(reason="issues/5851"), - ), - pytest.param( - {"page_size": "a"}, - 100, - 0, - """ValueError("invalid literal for int() with base 10: \'a\'")""", - marks=pytest.mark.skip(reason="issues/5851"), - ), - ], - ) - def test_page_size( - self, - get_http_api_auth, - params, - expected_code, - expected_page_size, - expected_message, - ): - res = list_datasets(get_http_api_auth, params=params) - assert res["code"] == expected_code - if expected_code == 0: - assert len(res["data"]) == expected_page_size - else: - assert res["message"] == expected_message - - @pytest.mark.p3 - @pytest.mark.parametrize( - "params, expected_code, assertions, expected_message", - [ - ({"orderby": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), - ({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), - ({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", True)), ""), - pytest.param( - {"orderby": "name", "desc": "False"}, - 0, - lambda r: (is_sorted(r["data"]["docs"], "name", False)), - "", - marks=pytest.mark.skip(reason="issues/5851"), - ), - pytest.param( - {"orderby": "unknown"}, - 102, - 0, - "orderby should be create_time or update_time", - marks=pytest.mark.skip(reason="issues/5851"), - ), - ], - ) - def test_orderby( - self, - get_http_api_auth, - params, - expected_code, - assertions, - expected_message, - ): - res = list_datasets(get_http_api_auth, params=params) - assert res["code"] == expected_code - if expected_code == 0: - if callable(assertions): - assert assertions(res) - else: - assert res["message"] == expected_message - - @pytest.mark.p3 - @pytest.mark.parametrize( - "params, expected_code, assertions, 
expected_message", - [ - ({"desc": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), - ({"desc": "true"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), - ({"desc": "True"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), - ({"desc": True}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), - ({"desc": "false"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), - ({"desc": "False"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), - ({"desc": False}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), - ({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", False)), ""), - pytest.param( - {"desc": "unknown"}, - 102, - 0, - "desc should be true or false", - marks=pytest.mark.skip(reason="issues/5851"), - ), - ], - ) - def test_desc( - self, - get_http_api_auth, - params, - expected_code, - assertions, - expected_message, - ): - res = list_datasets(get_http_api_auth, params=params) - assert res["code"] == expected_code - if expected_code == 0: - if callable(assertions): - assert assertions(res) - else: - assert res["message"] == expected_message - - @pytest.mark.p1 - @pytest.mark.parametrize( - "params, expected_code, expected_num, expected_message", - [ - ({"name": None}, 0, 5, ""), - ({"name": ""}, 0, 5, ""), - ({"name": "dataset_1"}, 0, 1, ""), - ({"name": "unknown"}, 102, 0, "You don't own the dataset unknown"), - ], - ) - def test_name(self, get_http_api_auth, params, expected_code, expected_num, expected_message): - res = list_datasets(get_http_api_auth, params=params) - assert res["code"] == expected_code - if expected_code == 0: - if params["name"] in [None, ""]: - assert len(res["data"]) == expected_num - else: - assert res["data"][0]["name"] == params["name"] - else: - assert res["message"] == expected_message - - @pytest.mark.p1 - @pytest.mark.parametrize( - "dataset_id, expected_code, expected_num, 
expected_message", - [ - (None, 0, 5, ""), - ("", 0, 5, ""), - (lambda r: r[0], 0, 1, ""), - ("unknown", 102, 0, "You don't own the dataset unknown"), - ], - ) - def test_id( - self, - get_http_api_auth, - add_datasets, - dataset_id, - expected_code, - expected_num, - expected_message, - ): - dataset_ids = add_datasets - if callable(dataset_id): - params = {"id": dataset_id(dataset_ids)} - else: - params = {"id": dataset_id} - - res = list_datasets(get_http_api_auth, params=params) - assert res["code"] == expected_code - if expected_code == 0: - if params["id"] in [None, ""]: - assert len(res["data"]) == expected_num - else: - assert res["data"][0]["id"] == params["id"] - else: - assert res["message"] == expected_message - - @pytest.mark.p3 - @pytest.mark.parametrize( - "dataset_id, name, expected_code, expected_num, expected_message", - [ - (lambda r: r[0], "dataset_0", 0, 1, ""), - (lambda r: r[0], "dataset_1", 0, 0, ""), - (lambda r: r[0], "unknown", 102, 0, "You don't own the dataset unknown"), - ("id", "dataset_0", 102, 0, "You don't own the dataset id"), - ], - ) - def test_name_and_id( - self, - get_http_api_auth, - add_datasets, - dataset_id, - name, - expected_code, - expected_num, - expected_message, - ): - dataset_ids = add_datasets - if callable(dataset_id): - params = {"id": dataset_id(dataset_ids), "name": name} - else: - params = {"id": dataset_id, "name": name} - - res = list_datasets(get_http_api_auth, params=params) - assert res["code"] == expected_code - if expected_code == 0: - assert len(res["data"]) == expected_num - else: - assert res["message"] == expected_message - +class TestCapability: @pytest.mark.p3 def test_concurrent_list(self, get_http_api_auth): with ThreadPoolExecutor(max_workers=5) as executor: futures = [executor.submit(list_datasets, get_http_api_auth) for i in range(100)] responses = [f.result() for f in futures] - assert all(r["code"] == 0 for r in responses) + assert all(r["code"] == 0 for r in responses), responses + + 
+@pytest.mark.usefixtures("add_datasets") +class TestDatasetsList: + @pytest.mark.p1 + def test_params_unset(self, get_http_api_auth): + res = list_datasets(get_http_api_auth, None) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p2 + def test_params_empty(self, get_http_api_auth): + res = list_datasets(get_http_api_auth, {}) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + ({"page": 2, "page_size": 2}, 2), + ({"page": 3, "page_size": 2}, 1), + ({"page": 4, "page_size": 2}, 0), + ({"page": "2", "page_size": 2}, 2), + ({"page": 1, "page_size": 10}, 5), + ], + ids=["normal_middle_page", "normal_last_partial_page", "beyond_max_page", "string_page_number", "full_data_single_page"], + ) + def test_page(self, get_http_api_auth, params, expected_page_size): + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 0, res + assert len(res["data"]) == expected_page_size, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_code, expected_message", + [ + ({"page": 0}, 101, "Input should be greater than or equal to 1"), + ({"page": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"), + ], + ids=["page_0", "page_a"], + ) + def test_page_invalid(self, get_http_api_auth, params, expected_code, expected_message): + res = list_datasets(get_http_api_auth, params=params) + assert res["code"] == expected_code, res + assert expected_message in res["message"], res + + @pytest.mark.p2 + def test_page_none(self, get_http_api_auth): + params = {"page": None} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + ({"page_size": 1}, 1), + ({"page_size": 3}, 3), + ({"page_size": 5}, 5), + ({"page_size": 6}, 5), + ({"page_size": "1"}, 1), + ], + 
ids=["min_valid_page_size", "medium_page_size", "page_size_equals_total", "page_size_exceeds_total", "string_type_page_size"], + ) + def test_page_size(self, get_http_api_auth, params, expected_page_size): + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 0, res + assert len(res["data"]) == expected_page_size, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_code, expected_message", + [ + ({"page_size": 0}, 101, "Input should be greater than or equal to 1"), + ({"page_size": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"), + ], + ) + def test_page_size_invalid(self, get_http_api_auth, params, expected_code, expected_message): + res = list_datasets(get_http_api_auth, params) + assert res["code"] == expected_code, res + assert expected_message in res["message"], res + + @pytest.mark.p2 + def test_page_size_none(self, get_http_api_auth): + params = {"page_size": None} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, assertions", + [ + ({"orderby": "create_time"}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"orderby": "update_time"}, lambda r: (is_sorted(r["data"], "update_time", True))), + ({"orderby": "CREATE_TIME"}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"orderby": "UPDATE_TIME"}, lambda r: (is_sorted(r["data"], "update_time", True))), + ({"orderby": " create_time "}, lambda r: (is_sorted(r["data"], "update_time", True))), + ], + ids=["orderby_create_time", "orderby_update_time", "orderby_create_time_upper", "orderby_update_time_upper", "whitespace"], + ) + def test_orderby(self, get_http_api_auth, params, assertions): + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 0, res + if callable(assertions): + assert assertions(res), res @pytest.mark.p3 - def test_invalid_params(self, get_http_api_auth): - 
params = {"a": "b"} - res = list_datasets(get_http_api_auth, params=params) + @pytest.mark.parametrize( + "params", + [ + {"orderby": ""}, + {"orderby": "unknown"}, + ], + ids=["empty", "unknown"], + ) + def test_orderby_invalid(self, get_http_api_auth, params): + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 101, res + assert "Input should be 'create_time' or 'update_time'" in res["message"], res + + @pytest.mark.p3 + def test_orderby_none(self, get_http_api_auth): + params = {"order_by": None} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 0, res + assert is_sorted(res["data"], "create_time", True), res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, assertions", + [ + ({"desc": True}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"desc": False}, lambda r: (is_sorted(r["data"], "create_time", False))), + ({"desc": "true"}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"desc": "false"}, lambda r: (is_sorted(r["data"], "create_time", False))), + ({"desc": 1}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"desc": 0}, lambda r: (is_sorted(r["data"], "create_time", False))), + ({"desc": "yes"}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"desc": "no"}, lambda r: (is_sorted(r["data"], "create_time", False))), + ({"desc": "y"}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"desc": "n"}, lambda r: (is_sorted(r["data"], "create_time", False))), + ], + ids=["desc=True", "desc=False", "desc=true", "desc=false", "desc=1", "desc=0", "desc=yes", "desc=no", "desc=y", "desc=n"], + ) + def test_desc(self, get_http_api_auth, params, assertions): + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 0, res + if callable(assertions): + assert assertions(res), res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params", + [ + {"desc": 3.14}, + {"desc": "unknown"}, + ], + ids=["empty", "unknown"], + ) + def test_desc_invalid(self, 
get_http_api_auth, params): + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 101, res + assert "Input should be a valid boolean, unable to interpret input" in res["message"], res + + @pytest.mark.p3 + def test_desc_none(self, get_http_api_auth): + params = {"desc": None} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 0, res + assert is_sorted(res["data"], "create_time", True), res + + @pytest.mark.p1 + def test_name(self, get_http_api_auth): + params = {"name": "dataset_1"} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 0, res + assert len(res["data"]) == 1, res + assert res["data"][0]["name"] == "dataset_1", res + + @pytest.mark.p2 + def test_name_wrong(self, get_http_api_auth): + params = {"name": "wrong name"} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 108, res + assert "lacks permission for dataset" in res["message"], res + + @pytest.mark.p2 + def test_name_empty(self, get_http_api_auth): + params = {"name": ""} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p2 + def test_name_none(self, get_http_api_auth): + params = {"name": None} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p1 + def test_id(self, get_http_api_auth, add_datasets): + dataset_ids = add_datasets + params = {"id": dataset_ids[0]} + res = list_datasets(get_http_api_auth, params) assert res["code"] == 0 - assert len(res["data"]) == 5 + assert len(res["data"]) == 1 + assert res["data"][0]["id"] == dataset_ids[0] + + @pytest.mark.p2 + def test_id_not_uuid(self, get_http_api_auth): + params = {"id": "not_uuid"} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 101, res + assert "Invalid UUID1 format" in res["message"], res + + @pytest.mark.p2 + def test_id_not_uuid1(self, get_http_api_auth): + params = 
{"id": uuid.uuid4().hex} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 101, res + assert "Invalid UUID1 format" in res["message"], res + + @pytest.mark.p2 + def test_id_wrong_uuid(self, get_http_api_auth): + params = {"id": "d94a8dc02c9711f0930f7fbc369eab6d"} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 108, res + assert "lacks permission for dataset" in res["message"], res + + @pytest.mark.p2 + def test_id_empty(self, get_http_api_auth): + params = {"id": ""} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 101, res + assert "Invalid UUID1 format" in res["message"], res + + @pytest.mark.p2 + def test_id_none(self, get_http_api_auth): + params = {"id": None} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "func, name, expected_num", + [ + (lambda r: r[0], "dataset_0", 1), + (lambda r: r[0], "dataset_1", 0), + ], + ids=["name_and_id_match", "name_and_id_mismatch"], + ) + def test_name_and_id(self, get_http_api_auth, add_datasets, func, name, expected_num): + dataset_ids = add_datasets + if callable(func): + params = {"id": func(dataset_ids), "name": name} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 0, res + assert len(res["data"]) == expected_num, res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "dataset_id, name", + [ + (lambda r: r[0], "wrong_name"), + (uuid.uuid1().hex, "dataset_0"), + ], + ids=["name", "id"], + ) + def test_name_and_id_wrong(self, get_http_api_auth, add_datasets, dataset_id, name): + dataset_ids = add_datasets + if callable(dataset_id): + params = {"id": dataset_id(dataset_ids), "name": name} + else: + params = {"id": dataset_id, "name": name} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 108, res + assert "lacks permission for dataset" in res["message"], res + + @pytest.mark.p2 + def 
test_field_unsupported(self, get_http_api_auth): + params = {"unknown_field": "unknown_field"} + res = list_datasets(get_http_api_auth, params) + assert res["code"] == 101, res + assert "Extra inputs are not permitted" in res["message"], res diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py b/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py index 40d2dc01a..39d0cafc1 100644 --- a/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py +++ b/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import uuid from concurrent.futures import ThreadPoolExecutor import pytest @@ -98,16 +99,23 @@ class TestCapability: class TestDatasetUpdate: @pytest.mark.p3 def test_dataset_id_not_uuid(self, get_http_api_auth): - payload = {"name": "not_uuid"} + payload = {"name": "not uuid"} res = update_dataset(get_http_api_auth, "not_uuid", payload) assert res["code"] == 101, res - assert "Input should be a valid UUID" in res["message"], res + assert "Invalid UUID1 format" in res["message"], res + + @pytest.mark.p3 + def test_dataset_id_not_uuid1(self, get_http_api_auth): + payload = {"name": "not uuid1"} + res = update_dataset(get_http_api_auth, uuid.uuid4().hex, payload) + assert res["code"] == 101, res + assert "Invalid UUID1 format" in res["message"], res @pytest.mark.p3 def test_dataset_id_wrong_uuid(self, get_http_api_auth): - payload = {"name": "wrong_uuid"} + payload = {"name": "wrong uuid"} res = update_dataset(get_http_api_auth, "d94a8dc02c9711f0930f7fbc369eab6d", payload) - assert res["code"] == 102, res + assert res["code"] == 108, res assert "lacks permission for dataset" in res["message"], res @pytest.mark.p1 @@ -322,8 +330,9 @@ class TestDatasetUpdate: "team", "ME", "TEAM", + " ME ", ], - ids=["me", "team", "me_upercase", "team_upercase"], + 
ids=["me", "team", "me_upercase", "team_upercase", "whitespace"], ) def test_permission(self, get_http_api_auth, add_dataset_func, permission): dataset_id = add_dataset_func @@ -333,7 +342,7 @@ class TestDatasetUpdate: res = list_datasets(get_http_api_auth) assert res["code"] == 0, res - assert res["data"][0]["permission"] == permission.lower(), res + assert res["data"][0]["permission"] == permission.lower().strip(), res @pytest.mark.p2 @pytest.mark.parametrize( @@ -734,7 +743,6 @@ class TestDatasetUpdate: assert res["code"] == 0, res res = list_datasets(get_http_api_auth) - print(res) assert res["code"] == 0, res assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res @@ -757,7 +765,6 @@ class TestDatasetUpdate: assert res["code"] == 0, res res = list_datasets(get_http_api_auth, {"id": dataset_id}) - print(res) assert res["code"] == 0, res assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res