Refactor: HTTP API list datasets / test cases / docs (#7720)

### What problem does this PR solve?

This PR introduces Pydantic-based validation for the list datasets HTTP
API, improving code clarity and robustness. Key changes include:

- Pydantic validation: adds `ListDatasetReq` and a `validate_and_parse_request_args` helper to validate query parameters for the list-datasets endpoint.
- Error handling: introduces `get_error_permission_result` and `get_error_operating_result` helpers and wraps database calls in `OperationalError` handling.
- Test updates: adjusts expected error codes/messages for the new result helpers and adds UUID1-format and whitespace-permission test cases.
- Documentation updates: corrects the documented `page_size` default from 1024 to 30.

### Type of change

- [x] Documentation Update
- [x] Refactoring
This commit is contained in:
liu an 2025-05-20 09:58:26 +08:00 committed by GitHub
parent 6ed81d6774
commit fed1221302
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 683 additions and 412 deletions

View File

@ -32,12 +32,22 @@ from api.utils.api_utils import (
deep_merge, deep_merge,
get_error_argument_result, get_error_argument_result,
get_error_data_result, get_error_data_result,
get_error_operating_result,
get_error_permission_result,
get_parser_config, get_parser_config,
get_result, get_result,
remap_dictionary_keys,
token_required, token_required,
verify_embedding_availability, verify_embedding_availability,
) )
from api.utils.validation_utils import CreateDatasetReq, DeleteDatasetReq, UpdateDatasetReq, validate_and_parse_json_request from api.utils.validation_utils import (
CreateDatasetReq,
DeleteDatasetReq,
ListDatasetReq,
UpdateDatasetReq,
validate_and_parse_json_request,
validate_and_parse_request_args,
)
@manager.route("/datasets", methods=["POST"]) # noqa: F821 @manager.route("/datasets", methods=["POST"]) # noqa: F821
@ -113,7 +123,7 @@ def create(tenant_id):
try: try:
if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") return get_error_operating_result(message=f"Dataset name '{req['name']}' already exists")
except OperationalError as e: except OperationalError as e:
logging.exception(e) logging.exception(e)
return get_error_data_result(message="Database operation failed") return get_error_data_result(message="Database operation failed")
@ -126,7 +136,7 @@ def create(tenant_id):
try: try:
ok, t = TenantService.get_by_id(tenant_id) ok, t = TenantService.get_by_id(tenant_id)
if not ok: if not ok:
return get_error_data_result(message="Tenant not found") return get_error_permission_result(message="Tenant not found")
except OperationalError as e: except OperationalError as e:
logging.exception(e) logging.exception(e)
return get_error_data_result(message="Database operation failed") return get_error_data_result(message="Database operation failed")
@ -153,16 +163,7 @@ def create(tenant_id):
logging.exception(e) logging.exception(e)
return get_error_data_result(message="Database operation failed") return get_error_data_result(message="Database operation failed")
response_data = {} response_data = remap_dictionary_keys(k.to_dict())
key_mapping = {
"chunk_num": "chunk_count",
"doc_num": "document_count",
"parser_id": "chunk_method",
"embd_id": "embedding_model",
}
for key, value in k.to_dict().items():
new_key = key_mapping.get(key, key)
response_data[new_key] = value
return get_result(data=response_data) return get_result(data=response_data)
@ -232,7 +233,7 @@ def delete(tenant_id):
logging.exception(e) logging.exception(e)
return get_error_data_result(message="Database operation failed") return get_error_data_result(message="Database operation failed")
if len(error_kb_ids) > 0: if len(error_kb_ids) > 0:
return get_error_data_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
errors = [] errors = []
success_count = 0 success_count = 0
@ -347,7 +348,7 @@ def update(tenant_id, dataset_id):
try: try:
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
if kb is None: if kb is None:
return get_error_data_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
except OperationalError as e: except OperationalError as e:
logging.exception(e) logging.exception(e)
return get_error_data_result(message="Database operation failed") return get_error_data_result(message="Database operation failed")
@ -418,7 +419,7 @@ def list_datasets(tenant_id):
name: page_size name: page_size
type: integer type: integer
required: false required: false
default: 1024 default: 30
description: Number of items per page. description: Number of items per page.
- in: query - in: query
name: orderby name: orderby
@ -445,47 +446,46 @@ def list_datasets(tenant_id):
items: items:
type: object type: object
""" """
id = request.args.get("id") args, err = validate_and_parse_request_args(request, ListDatasetReq)
name = request.args.get("name") if err is not None:
if id: return get_error_argument_result(err)
kbs = KnowledgebaseService.get_kb_by_id(id, tenant_id)
kb_id = request.args.get("id")
name = args.get("name")
if kb_id:
try:
kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id)
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
if not kbs: if not kbs:
return get_error_data_result(f"You don't own the dataset {id}") return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'")
if name: if name:
kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id) try:
kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id)
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
if not kbs: if not kbs:
return get_error_data_result(f"You don't own the dataset {name}") return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 30)) try:
orderby = request.args.get("orderby", "create_time") tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
if request.args.get("desc", "false").lower() not in ["true", "false"]: kbs = KnowledgebaseService.get_list(
return get_error_data_result("desc should be true or false") [m["tenant_id"] for m in tenants],
if request.args.get("desc", "true").lower() == "false": tenant_id,
desc = False args["page"],
else: args["page_size"],
desc = True args["orderby"],
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) args["desc"],
kbs = KnowledgebaseService.get_list( kb_id,
[m["tenant_id"] for m in tenants], name,
tenant_id, )
page_number, except OperationalError as e:
items_per_page, logging.exception(e)
orderby, return get_error_data_result(message="Database operation failed")
desc,
id, response_data_list = []
name,
)
renamed_list = []
for kb in kbs: for kb in kbs:
key_mapping = { response_data_list.append(remap_dictionary_keys(kb))
"chunk_num": "chunk_count", return get_result(data=response_data_list)
"doc_num": "document_count",
"parser_id": "chunk_method",
"embd_id": "embedding_model",
}
renamed_data = {}
for key, value in kb.items():
new_key = key_mapping.get(key, key)
renamed_data[new_key] = value
renamed_list.append(renamed_data)
return get_result(data=renamed_list)

View File

@ -329,6 +329,14 @@ def get_error_argument_result(message="Invalid arguments"):
return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=message) return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=message)
def get_error_permission_result(message="Permission error"):
return get_result(code=settings.RetCode.PERMISSION_ERROR, message=message)
def get_error_operating_result(message="Operating error"):
return get_result(code=settings.RetCode.OPERATING_ERROR, message=message)
def generate_confirmation_token(tenant_id): def generate_confirmation_token(tenant_id):
serializer = URLSafeTimedSerializer(tenant_id) serializer = URLSafeTimedSerializer(tenant_id)
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34] return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]
@ -514,3 +522,38 @@ def deep_merge(default: dict, custom: dict) -> dict:
base_dict[key] = val base_dict[key] = val
return merged return merged
def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
"""
Transform dictionary keys using a configurable mapping schema.
Args:
source_data: Original dictionary to process
key_aliases: Custom key transformation rules (Optional)
When provided, overrides default key mapping
Format: {<original_key>: <new_key>, ...}
Returns:
dict: New dictionary with transformed keys preserving original values
Example:
>>> input_data = {"old_key": "value", "another_field": 42}
>>> remap_dictionary_keys(input_data, {"old_key": "new_key"})
{'new_key': 'value', 'another_field': 42}
"""
DEFAULT_KEY_MAP = {
"chunk_num": "chunk_count",
"doc_num": "document_count",
"parser_id": "chunk_method",
"embd_id": "embedding_model",
}
transformed_data = {}
mapping = key_aliases or DEFAULT_KEY_MAP
for original_key, value in source_data.items():
mapped_key = mapping.get(original_key, original_key)
transformed_data[mapped_key] = value
return transformed_data

View File

@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import uuid
from collections import Counter from collections import Counter
from enum import auto from enum import auto
from typing import Annotated, Any from typing import Annotated, Any
from uuid import UUID
from flask import Request from flask import Request
from pydantic import UUID1, BaseModel, Field, StringConstraints, ValidationError, field_serializer, field_validator from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator
from pydantic_core import PydanticCustomError from pydantic_core import PydanticCustomError
from strenum import StrEnum from strenum import StrEnum
from werkzeug.exceptions import BadRequest, UnsupportedMediaType from werkzeug.exceptions import BadRequest, UnsupportedMediaType
@ -102,6 +102,71 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
return parsed_payload, None return parsed_payload, None
def validate_and_parse_request_args(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None) -> tuple[dict[str, Any] | None, str | None]:
"""
Validates and parses request arguments against a Pydantic model.
This function performs a complete request validation workflow:
1. Extracts query parameters from the request
2. Merges with optional extra values (if provided)
3. Validates against the specified Pydantic model
4. Cleans the output by removing extra values
5. Returns either parsed data or an error message
Args:
request (Request): Web framework request object containing query parameters
validator (type[BaseModel]): Pydantic model class for validation
extras (dict[str, Any] | None): Optional additional values to include in validation
but exclude from final output. Defaults to None.
Returns:
tuple[dict[str, Any] | None, str | None]:
- First element: Validated/parsed arguments as dict if successful, None otherwise
- Second element: Formatted error message if validation failed, None otherwise
Behavior:
- Query parameters are merged with extras before validation
- Extras are automatically removed from the final output
- All validation errors are formatted into a human-readable string
Raises:
TypeError: If validator is not a Pydantic BaseModel subclass
Examples:
Successful validation:
>>> validate_and_parse_request_args(request, MyValidator)
({'param1': 'value'}, None)
Failed validation:
>>> validate_and_parse_request_args(request, MyValidator)
(None, "param1: Field required")
With extras:
>>> validate_and_parse_request_args(request, MyValidator, extras={'internal_id': 123})
({'param1': 'value'}, None) # internal_id removed from output
Notes:
- Uses request.args.to_dict() for Flask-compatible parameter extraction
- Maintains immutability of original request arguments
- Preserves type conversion from Pydantic validation
"""
args = request.args.to_dict(flat=True)
try:
if extras is not None:
args.update(extras)
validated_args = validator(**args)
except ValidationError as e:
return None, format_validation_error_message(e)
parsed_args = validated_args.model_dump()
if extras is not None:
for key in list(parsed_args.keys()):
if key in extras:
del parsed_args[key]
return parsed_args, None
def format_validation_error_message(e: ValidationError) -> str: def format_validation_error_message(e: ValidationError) -> str:
""" """
Formats validation errors into a standardized string format. Formats validation errors into a standardized string format.
@ -143,6 +208,105 @@ def format_validation_error_message(e: ValidationError) -> str:
return "\n".join(error_messages) return "\n".join(error_messages)
def normalize_str(v: Any) -> Any:
"""
Normalizes string values to a standard format while preserving non-string inputs.
Performs the following transformations when input is a string:
1. Trims leading/trailing whitespace (str.strip())
2. Converts to lowercase (str.lower())
Non-string inputs are returned unchanged, making this function safe for mixed-type
processing pipelines.
Args:
v (Any): Input value to normalize. Accepts any Python object.
Returns:
Any: Normalized string if input was string-type, original value otherwise.
Behavior Examples:
String Input: " Admin " "admin"
Empty String: " " "" (empty string)
Non-String:
- 123 123
- None None
- ["User"] ["User"]
Typical Use Cases:
- Standardizing user input
- Preparing data for case-insensitive comparison
- Cleaning API parameters
- Normalizing configuration values
Edge Cases:
- Unicode whitespace is handled by str.strip()
- Locale-independent lowercasing (str.lower())
- Preserves falsy values (0, False, etc.)
Example:
>>> normalize_str(" ReadOnly ")
'readonly'
>>> normalize_str(42)
42
"""
if isinstance(v, str):
stripped = v.strip()
normalized = stripped.lower()
return normalized
return v
def validate_uuid1_hex(v: Any) -> str:
"""
Validates and converts input to a UUID version 1 hexadecimal string.
This function performs strict validation and normalization:
1. Accepts either UUID objects or UUID-formatted strings
2. Verifies the UUID is version 1 (time-based)
3. Returns the 32-character hexadecimal representation
Args:
v (Any): Input value to validate. Can be:
- UUID object (must be version 1)
- String in UUID format (e.g. "550e8400-e29b-41d4-a716-446655440000")
Returns:
str: 32-character lowercase hexadecimal string without hyphens
Example: "550e8400e29b41d4a716446655440000"
Raises:
PydanticCustomError: With code "invalid_UUID1_format" when:
- Input is not a UUID object or valid UUID string
- UUID version is not 1
- String doesn't match UUID format
Examples:
Valid cases:
>>> validate_uuid1_hex("550e8400-e29b-41d4-a716-446655440000")
'550e8400e29b41d4a716446655440000'
>>> validate_uuid1_hex(UUID('550e8400-e29b-41d4-a716-446655440000'))
'550e8400e29b41d4a716446655440000'
Invalid cases:
>>> validate_uuid1_hex("not-a-uuid") # raises PydanticCustomError
>>> validate_uuid1_hex(12345) # raises PydanticCustomError
>>> validate_uuid1_hex(UUID(int=0)) # v4, raises PydanticCustomError
Notes:
- Uses Python's built-in UUID parser for format validation
- Version check prevents accidental use of other UUID versions
- Hyphens in input strings are automatically removed in output
"""
try:
uuid_obj = UUID(v) if isinstance(v, str) else v
if uuid_obj.version != 1:
raise PydanticCustomError("invalid_UUID1_format", "Must be a UUID1 format")
return uuid_obj.hex
except (AttributeError, ValueError, TypeError):
raise PydanticCustomError("invalid_UUID1_format", "Invalid UUID1 format")
class PermissionEnum(StrEnum): class PermissionEnum(StrEnum):
me = auto() me = auto()
team = auto() team = auto()
@ -217,8 +381,8 @@ class CreateDatasetReq(Base):
avatar: str | None = Field(default=None, max_length=65535) avatar: str | None = Field(default=None, max_length=65535)
description: str | None = Field(default=None, max_length=65535) description: str | None = Field(default=None, max_length=65535)
embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")] embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)] permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")] chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
pagerank: int = Field(default=0, ge=0, le=100) pagerank: int = Field(default=0, ge=0, le=100)
parser_config: ParserConfig | None = Field(default=None) parser_config: ParserConfig | None = Field(default=None)
@ -315,22 +479,8 @@ class CreateDatasetReq(Base):
@field_validator("permission", mode="before") @field_validator("permission", mode="before")
@classmethod @classmethod
def permission_auto_lowercase(cls, v: Any) -> Any: def normalize_permission(cls, v: Any) -> Any:
""" return normalize_str(v)
Normalize permission input to lowercase for consistent PermissionEnum matching.
Args:
v (Any): Raw input value for the permission field
Returns:
Lowercase string if input is string type, otherwise returns original value
Behavior:
- Converts string inputs to lowercase (e.g., "ME" "me")
- Non-string values pass through unchanged
- Works in validation pre-processing stage (before enum conversion)
"""
return v.lower() if isinstance(v, str) else v
@field_validator("parser_config", mode="before") @field_validator("parser_config", mode="before")
@classmethod @classmethod
@ -387,93 +537,117 @@ class CreateDatasetReq(Base):
class UpdateDatasetReq(CreateDatasetReq): class UpdateDatasetReq(CreateDatasetReq):
dataset_id: UUID1 = Field(...) dataset_id: str = Field(...)
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")] name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
@field_serializer("dataset_id") @field_validator("dataset_id", mode="before")
def serialize_uuid_to_hex(self, v: uuid.UUID) -> str: @classmethod
""" def validate_dataset_id(cls, v: Any) -> str:
Serializes a UUID version 1 object to its hexadecimal string representation. return validate_uuid1_hex(v)
This field serializer specifically handles UUID version 1 objects, converting them
to their canonical 32-character hexadecimal format without hyphens. The conversion
is designed for consistent serialization in API responses and database storage.
Args:
v (uuid.UUID1): The UUID version 1 object to serialize. Must be a valid
UUID1 instance generated by Python's uuid module.
Returns:
str: 32-character lowercase hexadecimal string representation
Example: "550e8400e29b41d4a716446655440000"
Raises:
AttributeError: If input is not a proper UUID object (missing hex attribute)
TypeError: If input is not a UUID1 instance (when type checking is enabled)
Notes:
- Version 1 UUIDs contain timestamp and MAC address information
- The .hex property automatically converts to lowercase hexadecimal
- For cross-version compatibility, consider typing as uuid.UUID instead
"""
return v.hex
class DeleteReq(Base): class DeleteReq(Base):
ids: list[UUID1] | None = Field(...) ids: list[str] | None = Field(...)
@field_validator("ids", mode="after") @field_validator("ids", mode="after")
def check_duplicate_ids(cls, v: list[UUID1] | None) -> list[str] | None: @classmethod
def validate_ids(cls, v_list: list[str] | None) -> list[str] | None:
""" """
Validates and converts a list of UUID1 objects to hexadecimal strings while checking for duplicates. Validates and normalizes a list of UUID strings with None handling.
This validator implements a three-stage processing pipeline: This post-processing validator performs:
1. Null Handling - returns None for empty/null input 1. None input handling (pass-through)
2. UUID Conversion - transforms UUID objects to hex strings 2. UUID version 1 validation for each list item
3. Duplicate Validation - ensures all IDs are unique 3. Duplicate value detection
4. Returns normalized UUID hex strings or None
Behavior Specifications:
- Input: None Returns None (indicates no operation)
- Input: [] Returns [] (empty list for explicit no-op)
- Input: [UUID1,...] Returns validated hex strings
- Duplicates: Raises formatted PydanticCustomError
Args: Args:
v (list[UUID1] | None): v_list (list[str] | None): Input list that has passed initial validation.
- None: Indicates no datasets should be processed Either a list of UUID strings or None.
- Empty list: Explicit empty operation
- Populated list: Dataset UUIDs to validate/convert
Returns: Returns:
list[str] | None: list[str] | None:
- None when input is None - None if input was None
- List of 32-character hex strings (lowercase, no hyphens) - List of normalized UUID hex strings otherwise:
Example: ["550e8400e29b41d4a716446655440000"] * 32-character lowercase
* Valid UUID version 1
* Unique within list
Raises: Raises:
PydanticCustomError: When duplicates detected, containing: PydanticCustomError: With structured error details when:
- Error type: "duplicate_uuids" - "invalid_UUID1_format": Any string fails UUIDv1 validation
- Template message: "Duplicate ids: '{duplicate_ids}'" - "duplicate_uuids": If duplicate IDs are detected
- Context: {"duplicate_ids": "id1, id2, ..."}
Example: Validation Rules:
>>> validate([UUID("..."), UUID("...")]) - None input returns None
["2cdf0456e9a711ee8000000000000000", ...] - Empty list returns empty list
- All non-None items must be valid UUIDv1
- No duplicates permitted
- Original order preserved
>>> validate([UUID("..."), UUID("...")]) # Duplicates Examples:
PydanticCustomError: Duplicate ids: '2cdf0456e9a711ee8000000000000000' Valid cases:
>>> validate_ids(None)
None
>>> validate_ids([])
[]
>>> validate_ids(["550e8400-e29b-41d4-a716-446655440000"])
["550e8400e29b41d4a716446655440000"]
Invalid cases:
>>> validate_ids(["invalid"])
# raises PydanticCustomError(invalid_UUID1_format)
>>> validate_ids(["550e...", "550e..."])
# raises PydanticCustomError(duplicate_uuids)
Security Notes:
- Validates UUID version to prevent version spoofing
- Duplicate check prevents data injection
- None handling maintains pipeline integrity
""" """
if not v: if v_list is None:
return v return None
uuid_hex_list = [ids.hex for ids in v] ids_list = []
duplicates = [item for item, count in Counter(uuid_hex_list).items() if count > 1] for v in v_list:
try:
ids_list.append(validate_uuid1_hex(v))
except PydanticCustomError as e:
raise e
duplicates = [item for item, count in Counter(ids_list).items() if count > 1]
if duplicates: if duplicates:
duplicates_str = ", ".join(duplicates) duplicates_str = ", ".join(duplicates)
raise PydanticCustomError("duplicate_uuids", "Duplicate ids: '{duplicate_ids}'", {"duplicate_ids": duplicates_str}) raise PydanticCustomError("duplicate_uuids", "Duplicate ids: '{duplicate_ids}'", {"duplicate_ids": duplicates_str})
return uuid_hex_list return ids_list
class DeleteDatasetReq(DeleteReq): ... class DeleteDatasetReq(DeleteReq): ...
class OrderByEnum(StrEnum):
create_time = auto()
update_time = auto()
class BaseListReq(Base):
id: str | None = None
name: str | None = None
page: int = Field(default=1, ge=1)
page_size: int = Field(default=30, ge=1)
orderby: OrderByEnum = Field(default=OrderByEnum.create_time)
desc: bool = Field(default=True)
@field_validator("id", mode="before")
@classmethod
def validate_id(cls, v: Any) -> str:
return validate_uuid1_hex(v)
@field_validator("orderby", mode="before")
@classmethod
def normalize_orderby(cls, v: Any) -> Any:
return normalize_str(v)
class ListDatasetReq(BaseListReq): ...

View File

@ -122,7 +122,7 @@ class TestDatasetCreate:
assert res["code"] == 0, res assert res["code"] == 0, res
res = create_dataset(get_http_api_auth, payload) res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 102, res assert res["code"] == 103, res
assert res["message"] == f"Dataset name '{name}' already exists", res assert res["message"] == f"Dataset name '{name}' already exists", res
@pytest.mark.p3 @pytest.mark.p3
@ -134,7 +134,7 @@ class TestDatasetCreate:
payload = {"name": name.lower()} payload = {"name": name.lower()}
res = create_dataset(get_http_api_auth, payload) res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 102, res assert res["code"] == 103, res
assert res["message"] == f"Dataset name '{name.lower()}' already exists", res assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
@pytest.mark.p2 @pytest.mark.p2
@ -296,14 +296,15 @@ class TestDatasetCreate:
("team", "team"), ("team", "team"),
("me_upercase", "ME"), ("me_upercase", "ME"),
("team_upercase", "TEAM"), ("team_upercase", "TEAM"),
("whitespace", " ME "),
], ],
ids=["me", "team", "me_upercase", "team_upercase"], ids=["me", "team", "me_upercase", "team_upercase", "whitespace"],
) )
def test_permission(self, get_http_api_auth, name, permission): def test_permission(self, get_http_api_auth, name, permission):
payload = {"name": name, "permission": permission} payload = {"name": name, "permission": permission}
res = create_dataset(get_http_api_auth, payload) res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res assert res["code"] == 0, res
assert res["data"]["permission"] == permission.lower(), res assert res["data"]["permission"] == permission.lower().strip(), res
@pytest.mark.p2 @pytest.mark.p2
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import pytest import pytest
@ -40,8 +41,8 @@ class TestAuthorization:
) )
def test_auth_invalid(self, auth, expected_code, expected_message): def test_auth_invalid(self, auth, expected_code, expected_message):
res = delete_datasets(auth) res = delete_datasets(auth)
assert res["code"] == expected_code assert res["code"] == expected_code, res
assert res["message"] == expected_message assert res["message"] == expected_message, res
class TestRquest: class TestRquest:
@ -140,17 +141,25 @@ class TestDatasetsDelete:
payload = {"ids": ["not_uuid"]} payload = {"ids": ["not_uuid"]}
res = delete_datasets(get_http_api_auth, payload) res = delete_datasets(get_http_api_auth, payload)
assert res["code"] == 101, res assert res["code"] == 101, res
assert "Input should be a valid UUID" in res["message"], res assert "Invalid UUID1 format" in res["message"], res
res = list_datasets(get_http_api_auth) res = list_datasets(get_http_api_auth)
assert len(res["data"]) == 1, res assert len(res["data"]) == 1, res
@pytest.mark.p3
@pytest.mark.usefixtures("add_dataset_func")
def test_id_not_uuid1(self, get_http_api_auth):
payload = {"ids": [uuid.uuid4().hex]}
res = delete_datasets(get_http_api_auth, payload)
assert res["code"] == 101, res
assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p2 @pytest.mark.p2
@pytest.mark.usefixtures("add_dataset_func") @pytest.mark.usefixtures("add_dataset_func")
def test_id_wrong_uuid(self, get_http_api_auth): def test_id_wrong_uuid(self, get_http_api_auth):
payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]} payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]}
res = delete_datasets(get_http_api_auth, payload) res = delete_datasets(get_http_api_auth, payload)
assert res["code"] == 102, res assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res assert "lacks permission for dataset" in res["message"], res
res = list_datasets(get_http_api_auth) res = list_datasets(get_http_api_auth)
@ -170,7 +179,7 @@ class TestDatasetsDelete:
if callable(func): if callable(func):
payload = func(dataset_ids) payload = func(dataset_ids)
res = delete_datasets(get_http_api_auth, payload) res = delete_datasets(get_http_api_auth, payload)
assert res["code"] == 102, res assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res assert "lacks permission for dataset" in res["message"], res
res = list_datasets(get_http_api_auth) res = list_datasets(get_http_api_auth)
@ -195,7 +204,7 @@ class TestDatasetsDelete:
assert res["code"] == 0, res assert res["code"] == 0, res
res = delete_datasets(get_http_api_auth, payload) res = delete_datasets(get_http_api_auth, payload)
assert res["code"] == 102, res assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p2 @pytest.mark.p2

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import pytest import pytest
@ -21,8 +22,8 @@ from libs.auth import RAGFlowHttpApiAuth
from libs.utils import is_sorted from libs.utils import is_sorted
@pytest.mark.p1
class TestAuthorization: class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"auth, expected_code, expected_message", "auth, expected_code, expected_message",
[ [
@ -34,269 +35,305 @@ class TestAuthorization:
), ),
], ],
) )
def test_invalid_auth(self, auth, expected_code, expected_message): def test_auth_invalid(self, auth, expected_code, expected_message):
res = list_datasets(auth) res = list_datasets(auth)
assert res["code"] == expected_code assert res["code"] == expected_code, res
assert res["message"] == expected_message assert res["message"] == expected_message, res
@pytest.mark.usefixtures("add_datasets") class TestCapability:
class TestDatasetsList:
@pytest.mark.p1
def test_default(self, get_http_api_auth):
res = list_datasets(get_http_api_auth, params={})
assert res["code"] == 0
assert len(res["data"]) == 5
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page": None, "page_size": 2}, 0, 2, ""),
({"page": 0, "page_size": 2}, 0, 2, ""),
({"page": 2, "page_size": 2}, 0, 2, ""),
({"page": 3, "page_size": 2}, 0, 1, ""),
({"page": "3", "page_size": 2}, 0, 1, ""),
pytest.param(
{"page": -1, "page_size": 2},
100,
0,
"1064",
marks=pytest.mark.skip(reason="issues/5851"),
),
pytest.param(
{"page": "a", "page_size": 2},
100,
0,
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_page(self, get_http_api_auth, params, expected_code, expected_page_size, expected_message):
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page_size": None}, 0, 5, ""),
({"page_size": 0}, 0, 0, ""),
({"page_size": 1}, 0, 1, ""),
({"page_size": 6}, 0, 5, ""),
({"page_size": "1"}, 0, 1, ""),
pytest.param(
{"page_size": -1},
100,
0,
"1064",
marks=pytest.mark.skip(reason="issues/5851"),
),
pytest.param(
{"page_size": "a"},
100,
0,
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_page_size(
self,
get_http_api_auth,
params,
expected_code,
expected_page_size,
expected_message,
):
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"params, expected_code, assertions, expected_message",
[
({"orderby": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", True)), ""),
pytest.param(
{"orderby": "name", "desc": "False"},
0,
lambda r: (is_sorted(r["data"]["docs"], "name", False)),
"",
marks=pytest.mark.skip(reason="issues/5851"),
),
pytest.param(
{"orderby": "unknown"},
102,
0,
"orderby should be create_time or update_time",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_orderby(
self,
get_http_api_auth,
params,
expected_code,
assertions,
expected_message,
):
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
assert assertions(res)
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"params, expected_code, assertions, expected_message",
[
({"desc": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": "true"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": "True"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": True}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": "false"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
({"desc": "False"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
({"desc": False}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", False)), ""),
pytest.param(
{"desc": "unknown"},
102,
0,
"desc should be true or false",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_desc(
self,
get_http_api_auth,
params,
expected_code,
assertions,
expected_message,
):
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
assert assertions(res)
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_num, expected_message",
[
({"name": None}, 0, 5, ""),
({"name": ""}, 0, 5, ""),
({"name": "dataset_1"}, 0, 1, ""),
({"name": "unknown"}, 102, 0, "You don't own the dataset unknown"),
],
)
def test_name(self, get_http_api_auth, params, expected_code, expected_num, expected_message):
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["name"] in [None, ""]:
assert len(res["data"]) == expected_num
else:
assert res["data"][0]["name"] == params["name"]
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_num, expected_message",
[
(None, 0, 5, ""),
("", 0, 5, ""),
(lambda r: r[0], 0, 1, ""),
("unknown", 102, 0, "You don't own the dataset unknown"),
],
)
def test_id(
self,
get_http_api_auth,
add_datasets,
dataset_id,
expected_code,
expected_num,
expected_message,
):
dataset_ids = add_datasets
if callable(dataset_id):
params = {"id": dataset_id(dataset_ids)}
else:
params = {"id": dataset_id}
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["id"] in [None, ""]:
assert len(res["data"]) == expected_num
else:
assert res["data"][0]["id"] == params["id"]
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"dataset_id, name, expected_code, expected_num, expected_message",
[
(lambda r: r[0], "dataset_0", 0, 1, ""),
(lambda r: r[0], "dataset_1", 0, 0, ""),
(lambda r: r[0], "unknown", 102, 0, "You don't own the dataset unknown"),
("id", "dataset_0", 102, 0, "You don't own the dataset id"),
],
)
def test_name_and_id(
self,
get_http_api_auth,
add_datasets,
dataset_id,
name,
expected_code,
expected_num,
expected_message,
):
dataset_ids = add_datasets
if callable(dataset_id):
params = {"id": dataset_id(dataset_ids), "name": name}
else:
params = {"id": dataset_id, "name": name}
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_num
else:
assert res["message"] == expected_message
@pytest.mark.p3
def test_concurrent_list(self, get_http_api_auth):
    """Fire 100 list requests across 5 worker threads; every response must succeed.

    Guards against races in the list endpoint (e.g. shared state in the
    validation/pagination path) under concurrent read load.
    """
    with ThreadPoolExecutor(max_workers=5) as executor:
        # `_` instead of the unused loop variable `i`.
        futures = [executor.submit(list_datasets, get_http_api_auth) for _ in range(100)]
    responses = [f.result() for f in futures]
    assert all(r["code"] == 0 for r in responses), responses
@pytest.mark.usefixtures("add_datasets")
class TestDatasetsList:
@pytest.mark.p1
def test_params_unset(self, get_http_api_auth):
    """Sending no query string at all returns every dataset (5 seeded by the fixture)."""
    res = list_datasets(get_http_api_auth, None)
    assert res["code"] == 0, res
    assert len(res["data"]) == 5, res
@pytest.mark.p2
def test_params_empty(self, get_http_api_auth):
    """An empty params dict behaves like no params: all 5 fixture datasets are returned."""
    res = list_datasets(get_http_api_auth, {})
    assert res["code"] == 0, res
    assert len(res["data"]) == 5, res
@pytest.mark.p1
@pytest.mark.parametrize(
    "params, expected_page_size",
    [
        ({"page": 2, "page_size": 2}, 2),
        ({"page": 3, "page_size": 2}, 1),
        ({"page": 4, "page_size": 2}, 0),
        ({"page": "2", "page_size": 2}, 2),
        ({"page": 1, "page_size": 10}, 5),
    ],
    ids=["normal_middle_page", "normal_last_partial_page", "beyond_max_page", "string_page_number", "full_data_single_page"],
)
def test_page(self, get_http_api_auth, params, expected_page_size):
    """Paging over the 5 fixture datasets yields the expected slice size, incl. partial/empty pages and string page numbers."""
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 0, res
    assert len(res["data"]) == expected_page_size, res
@pytest.mark.p2
@pytest.mark.parametrize(
    "params, expected_code, expected_message",
    [
        ({"page": 0}, 101, "Input should be greater than or equal to 1"),
        ({"page": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"),
    ],
    ids=["page_0", "page_a"],
)
def test_page_invalid(self, get_http_api_auth, params, expected_code, expected_message):
    """Out-of-range or non-numeric page is rejected with code 101 and a Pydantic-style message."""
    res = list_datasets(get_http_api_auth, params=params)
    assert res["code"] == expected_code, res
    assert expected_message in res["message"], res
@pytest.mark.p2
def test_page_none(self, get_http_api_auth):
    """page=None must not error; the default page returns all 5 fixture datasets."""
    # NOTE(review): presumably the client drops None-valued params before sending — verify in list_datasets.
    params = {"page": None}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 0, res
    assert len(res["data"]) == 5, res
@pytest.mark.p1
@pytest.mark.parametrize(
    "params, expected_page_size",
    [
        ({"page_size": 1}, 1),
        ({"page_size": 3}, 3),
        ({"page_size": 5}, 5),
        ({"page_size": 6}, 5),
        ({"page_size": "1"}, 1),
    ],
    ids=["min_valid_page_size", "medium_page_size", "page_size_equals_total", "page_size_exceeds_total", "string_type_page_size"],
)
def test_page_size(self, get_http_api_auth, params, expected_page_size):
    """page_size caps the returned slice; a value beyond the 5-dataset total returns everything."""
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 0, res
    assert len(res["data"]) == expected_page_size, res
@pytest.mark.p2
@pytest.mark.parametrize(
    "params, expected_code, expected_message",
    [
        ({"page_size": 0}, 101, "Input should be greater than or equal to 1"),
        ({"page_size": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"),
    ],
    # ids added for consistency with test_page_invalid, which names its cases.
    ids=["page_size_0", "page_size_a"],
)
def test_page_size_invalid(self, get_http_api_auth, params, expected_code, expected_message):
    """Out-of-range or non-numeric page_size is rejected with code 101 and a Pydantic-style message."""
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == expected_code, res
    assert expected_message in res["message"], res
@pytest.mark.p2
def test_page_size_none(self, get_http_api_auth):
    """page_size=None must not error; the default page size covers all 5 fixture datasets."""
    params = {"page_size": None}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 0, res
    assert len(res["data"]) == 5, res
@pytest.mark.p2
@pytest.mark.parametrize(
    "params, assertions",
    [
        ({"orderby": "create_time"}, lambda r: (is_sorted(r["data"], "create_time", True))),
        ({"orderby": "update_time"}, lambda r: (is_sorted(r["data"], "update_time", True))),
        ({"orderby": "CREATE_TIME"}, lambda r: (is_sorted(r["data"], "create_time", True))),
        ({"orderby": "UPDATE_TIME"}, lambda r: (is_sorted(r["data"], "update_time", True))),
        # Fix: " create_time " is stripped/lower-cased server-side, so it must sort by
        # create_time — the previous assertion checked update_time and only passed because
        # both timestamps happen to be ordered the same way for freshly created fixtures.
        ({"orderby": " create_time "}, lambda r: (is_sorted(r["data"], "create_time", True))),
    ],
    ids=["orderby_create_time", "orderby_update_time", "orderby_create_time_upper", "orderby_update_time_upper", "whitespace"],
)
def test_orderby(self, get_http_api_auth, params, assertions):
    """orderby accepts create_time/update_time case-insensitively and tolerates surrounding whitespace."""
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 0, res
    if callable(assertions):
        assert assertions(res), res
@pytest.mark.p3
@pytest.mark.parametrize(
    "params",
    [
        {"orderby": ""},
        {"orderby": "unknown"},
    ],
    ids=["empty", "unknown"],
)
def test_orderby_invalid(self, get_http_api_auth, params):
    """Empty or unrecognised orderby values are rejected by the enum validator with code 101."""
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 101, res
    assert "Input should be 'create_time' or 'update_time'" in res["message"], res
@pytest.mark.p3
def test_orderby_none(self, get_http_api_auth):
    """orderby=None must fall back to the default ordering (create_time, descending)."""
    # Fix: the key was misspelled "order_by", which the API does not define, so the
    # parameter was silently dropped and the orderby=None path was never exercised.
    params = {"orderby": None}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 0, res
    assert is_sorted(res["data"], "create_time", True), res
@pytest.mark.p2
@pytest.mark.parametrize(
    "params, assertions",
    [
        ({"desc": True}, lambda r: (is_sorted(r["data"], "create_time", True))),
        ({"desc": False}, lambda r: (is_sorted(r["data"], "create_time", False))),
        ({"desc": "true"}, lambda r: (is_sorted(r["data"], "create_time", True))),
        ({"desc": "false"}, lambda r: (is_sorted(r["data"], "create_time", False))),
        ({"desc": 1}, lambda r: (is_sorted(r["data"], "create_time", True))),
        ({"desc": 0}, lambda r: (is_sorted(r["data"], "create_time", False))),
        ({"desc": "yes"}, lambda r: (is_sorted(r["data"], "create_time", True))),
        ({"desc": "no"}, lambda r: (is_sorted(r["data"], "create_time", False))),
        ({"desc": "y"}, lambda r: (is_sorted(r["data"], "create_time", True))),
        ({"desc": "n"}, lambda r: (is_sorted(r["data"], "create_time", False))),
    ],
    ids=["desc=True", "desc=False", "desc=true", "desc=false", "desc=1", "desc=0", "desc=yes", "desc=no", "desc=y", "desc=n"],
)
def test_desc(self, get_http_api_auth, params, assertions):
    """desc accepts every boolean spelling Pydantic coerces (bool, true/false, 1/0, yes/no, y/n)."""
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 0, res
    if callable(assertions):
        assert assertions(res), res
@pytest.mark.p3
@pytest.mark.parametrize(
    "params",
    [
        {"desc": 3.14},
        {"desc": "unknown"},
    ],
    # Fix: the first case is a float, not an empty value — the id "empty" was copied
    # from test_orderby_invalid and mislabelled the generated test name.
    ids=["float", "unknown"],
)
def test_desc_invalid(self, get_http_api_auth, params):
    """Values Pydantic cannot coerce to bool are rejected with code 101."""
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 101, res
    assert "Input should be a valid boolean, unable to interpret input" in res["message"], res
@pytest.mark.p3
def test_desc_none(self, get_http_api_auth):
    """desc=None must not error; the default ordering is create_time descending."""
    params = {"desc": None}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 0, res
    assert is_sorted(res["data"], "create_time", True), res
@pytest.mark.p1
def test_name(self, get_http_api_auth):
    """Filtering by an exact, owned dataset name returns only that dataset."""
    params = {"name": "dataset_1"}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 0, res
    assert len(res["data"]) == 1, res
    assert res["data"][0]["name"] == "dataset_1", res
@pytest.mark.p2
def test_name_wrong(self, get_http_api_auth):
    """An unknown name yields permission error 108, not an empty result list."""
    params = {"name": "wrong name"}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 108, res
    assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p2
def test_name_empty(self, get_http_api_auth):
    """An empty name filter is treated as no filter; all 5 datasets come back."""
    params = {"name": ""}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 0, res
    assert len(res["data"]) == 5, res
@pytest.mark.p2
def test_name_none(self, get_http_api_auth):
    """name=None is treated as no filter; all 5 datasets come back."""
    params = {"name": None}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 0, res
    assert len(res["data"]) == 5, res
@pytest.mark.p1
def test_id(self, get_http_api_auth, add_datasets):
    """Filtering by an owned dataset id returns exactly that dataset."""
    dataset_ids = add_datasets
    params = {"id": dataset_ids[0]}
    res = list_datasets(get_http_api_auth, params)
    # `, res` added to the asserts for consistency with the rest of the suite,
    # so a failure prints the offending response.
    assert res["code"] == 0, res
    assert len(res["data"]) == 1, res
    assert res["data"][0]["id"] == dataset_ids[0], res
@pytest.mark.p2
def test_id_not_uuid(self, get_http_api_auth):
    """A non-UUID id string fails UUID1 validation with code 101."""
    params = {"id": "not_uuid"}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 101, res
    assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p2
def test_id_not_uuid1(self, get_http_api_auth):
    """A well-formed UUID4 is rejected: the API requires UUID version 1 specifically."""
    params = {"id": uuid.uuid4().hex}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 101, res
    assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p2
def test_id_wrong_uuid(self, get_http_api_auth):
    """A syntactically valid UUID1 the tenant does not own yields permission error 108."""
    params = {"id": "d94a8dc02c9711f0930f7fbc369eab6d"}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 108, res
    assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p2
def test_id_empty(self, get_http_api_auth):
    """An empty id string (unlike None) is sent to the server and fails UUID1 validation."""
    params = {"id": ""}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 101, res
    assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p2
def test_id_none(self, get_http_api_auth):
    """id=None is treated as no filter; all 5 datasets come back."""
    params = {"id": None}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 0, res
    assert len(res["data"]) == 5, res
@pytest.mark.p2
@pytest.mark.parametrize(
    "func, name, expected_num",
    [
        (lambda r: r[0], "dataset_0", 1),
        (lambda r: r[0], "dataset_1", 0),
    ],
    ids=["name_and_id_match", "name_and_id_mismatch"],
)
def test_name_and_id(self, get_http_api_auth, add_datasets, func, name, expected_num):
    """When both id and name are given, a dataset is returned only if both match it."""
    dataset_ids = add_datasets
    # func selects which fixture dataset id to pair with the name under test.
    if callable(func):
        params = {"id": func(dataset_ids), "name": name}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 0, res
    assert len(res["data"]) == expected_num, res
@pytest.mark.p3
@pytest.mark.parametrize(
    "dataset_id, name",
    [
        (lambda r: r[0], "wrong_name"),
        (uuid.uuid1().hex, "dataset_0"),
    ],
    ids=["name", "id"],
)
def test_name_and_id_wrong(self, get_http_api_auth, add_datasets, dataset_id, name):
    """If either the id or the name is unknown/unowned, the API returns permission error 108."""
    dataset_ids = add_datasets
    # dataset_id may be a selector over fixture ids or a literal (unowned) UUID1.
    if callable(dataset_id):
        params = {"id": dataset_id(dataset_ids), "name": name}
    else:
        params = {"id": dataset_id, "name": name}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 108, res
    assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p2
def test_field_unsupported(self, get_http_api_auth):
    """Unknown query fields are rejected (Pydantic extra-fields-forbidden) with code 101."""
    params = {"unknown_field": "unknown_field"}
    res = list_datasets(get_http_api_auth, params)
    assert res["code"] == 101, res
    assert "Extra inputs are not permitted" in res["message"], res

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import pytest import pytest
@ -98,16 +99,23 @@ class TestCapability:
class TestDatasetUpdate: class TestDatasetUpdate:
@pytest.mark.p3 @pytest.mark.p3
def test_dataset_id_not_uuid(self, get_http_api_auth): def test_dataset_id_not_uuid(self, get_http_api_auth):
payload = {"name": "not_uuid"} payload = {"name": "not uuid"}
res = update_dataset(get_http_api_auth, "not_uuid", payload) res = update_dataset(get_http_api_auth, "not_uuid", payload)
assert res["code"] == 101, res assert res["code"] == 101, res
assert "Input should be a valid UUID" in res["message"], res assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p3
def test_dataset_id_not_uuid1(self, get_http_api_auth):
    """Updating with a well-formed UUID4 id is rejected: the API requires UUID version 1."""
    payload = {"name": "not uuid1"}
    res = update_dataset(get_http_api_auth, uuid.uuid4().hex, payload)
    assert res["code"] == 101, res
    assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p3 @pytest.mark.p3
def test_dataset_id_wrong_uuid(self, get_http_api_auth): def test_dataset_id_wrong_uuid(self, get_http_api_auth):
payload = {"name": "wrong_uuid"} payload = {"name": "wrong uuid"}
res = update_dataset(get_http_api_auth, "d94a8dc02c9711f0930f7fbc369eab6d", payload) res = update_dataset(get_http_api_auth, "d94a8dc02c9711f0930f7fbc369eab6d", payload)
assert res["code"] == 102, res assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p1 @pytest.mark.p1
@ -322,8 +330,9 @@ class TestDatasetUpdate:
"team", "team",
"ME", "ME",
"TEAM", "TEAM",
" ME ",
], ],
ids=["me", "team", "me_upercase", "team_upercase"], ids=["me", "team", "me_upercase", "team_upercase", "whitespace"],
) )
def test_permission(self, get_http_api_auth, add_dataset_func, permission): def test_permission(self, get_http_api_auth, add_dataset_func, permission):
dataset_id = add_dataset_func dataset_id = add_dataset_func
@ -333,7 +342,7 @@ class TestDatasetUpdate:
res = list_datasets(get_http_api_auth) res = list_datasets(get_http_api_auth)
assert res["code"] == 0, res assert res["code"] == 0, res
assert res["data"][0]["permission"] == permission.lower(), res assert res["data"][0]["permission"] == permission.lower().strip(), res
@pytest.mark.p2 @pytest.mark.p2
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -734,7 +743,6 @@ class TestDatasetUpdate:
assert res["code"] == 0, res assert res["code"] == 0, res
res = list_datasets(get_http_api_auth) res = list_datasets(get_http_api_auth)
print(res)
assert res["code"] == 0, res assert res["code"] == 0, res
assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res
@ -757,7 +765,6 @@ class TestDatasetUpdate:
assert res["code"] == 0, res assert res["code"] == 0, res
res = list_datasets(get_http_api_auth, {"id": dataset_id}) res = list_datasets(get_http_api_auth, {"id": dataset_id})
print(res)
assert res["code"] == 0, res assert res["code"] == 0, res
assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res