Refa: HTTP API update dataset / test cases / docs (#7564)

### What problem does this PR solve?

This PR introduces Pydantic-based validation for the update dataset HTTP
API, improving code clarity and robustness. Key changes include:
1. Pydantic Validation (see the sketch below)
2. Error Handling
3. Test Updates
4. Documentation Updates
5. Fix bug #5915
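
A minimal sketch of the validation idea, assuming a simplified model (field names and defaults are illustrative, not the exact schema added in this PR):

```python
from pydantic import BaseModel, Field, ValidationError

class UpdateDatasetReq(BaseModel):
    # illustrative subset of the real request model
    name: str = Field(default="", max_length=128)
    pagerank: int = Field(default=0, ge=0, le=100)

    class Config:
        extra = "forbid"  # unknown keys are rejected up front

try:
    req = UpdateDatasetReq(**{"pagerank": 50})
    update = req.model_dump(exclude_unset=True)  # only explicitly supplied fields
except ValidationError as exc:
    update = None  # turned into a structured error response by the API layer
```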

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Documentation Update
- [x] Refactoring
liu an 2025-05-09 19:17:08 +08:00 committed by GitHub
parent 31718581b5
commit 35e36cb945
12 changed files with 1283 additions and 552 deletions

View File

@ -27,22 +27,19 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import ( from api.utils.api_utils import (
check_duplicate_ids, check_duplicate_ids,
dataset_readonly_fields, deep_merge,
get_error_argument_result, get_error_argument_result,
get_error_data_result, get_error_data_result,
get_parser_config, get_parser_config,
get_result, get_result,
token_required, token_required,
valid,
valid_parser_config,
verify_embedding_availability, verify_embedding_availability,
) )
from api.utils.validation_utils import CreateDatasetReq, validate_and_parse_json_request from api.utils.validation_utils import CreateDatasetReq, UpdateDatasetReq, validate_and_parse_json_request
@manager.route("/datasets", methods=["POST"]) # noqa: F821 @manager.route("/datasets", methods=["POST"]) # noqa: F821
@ -117,8 +114,8 @@ def create(tenant_id):
return get_error_argument_result(err) return get_error_argument_result(err)
try: try:
if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_argument_result(message=f"Dataset name '{req['name']}' already exists") return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
except OperationalError as e: except OperationalError as e:
logging.exception(e) logging.exception(e)
return get_error_data_result(message="Database operation failed") return get_error_data_result(message="Database operation failed")
@ -145,7 +142,7 @@ def create(tenant_id):
try: try:
if not KnowledgebaseService.save(**req): if not KnowledgebaseService.save(**req):
return get_error_data_result(message="Database operation failed") return get_error_data_result(message="Create dataset error.(Database error)")
except OperationalError as e: except OperationalError as e:
logging.exception(e) logging.exception(e)
return get_error_data_result(message="Database operation failed") return get_error_data_result(message="Database operation failed")
@ -205,6 +202,7 @@ def delete(tenant_id):
schema: schema:
type: object type: object
""" """
errors = [] errors = []
success_count = 0 success_count = 0
req = request.json req = request.json
@ -291,16 +289,28 @@ def update(tenant_id, dataset_id):
name: name:
type: string type: string
description: New name of the dataset. description: New name of the dataset.
avatar:
type: string
description: Updated base64 encoding of the avatar.
description:
type: string
description: Updated description of the dataset.
embedding_model:
type: string
description: Updated embedding model Name.
permission: permission:
type: string type: string
enum: ['me', 'team'] enum: ['me', 'team']
description: Updated permission. description: Updated dataset permission.
chunk_method: chunk_method:
type: string type: string
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws", enum: ["naive", "book", "email", "laws", "manual", "one", "paper",
"presentation", "picture", "one", "email", "tag" "picture", "presentation", "qa", "table", "tag"
] ]
description: Updated chunking method. description: Updated chunking method.
pagerank:
type: integer
description: Updated page rank.
parser_config: parser_config:
type: object type: object
description: Updated parser configuration. description: Updated parser configuration.
@ -310,98 +320,56 @@ def update(tenant_id, dataset_id):
schema: schema:
type: object type: object
""" """
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): # Field name transformations during model dump:
return get_error_data_result(message="You don't own the dataset") # | Original | Dump Output |
req = request.json # |----------------|-------------|
for k in req.keys(): # | embedding_model| embd_id |
if dataset_readonly_fields(k): # | chunk_method | parser_id |
return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"'{k}' is readonly.") extras = {"dataset_id": dataset_id}
e, t = TenantService.get_by_id(tenant_id) req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
invalid_keys = {"id", "embd_id", "chunk_num", "doc_num", "parser_id", "create_date", "create_time", "created_by", "status", "token_num", "update_date", "update_time"} if err is not None:
if any(key in req for key in invalid_keys): return get_error_argument_result(err)
return get_error_data_result(message="The input parameters are invalid.")
permission = req.get("permission") if not req:
chunk_method = req.get("chunk_method") return get_error_argument_result(message="No properties were modified")
parser_config = req.get("parser_config")
valid_parser_config(parser_config) try:
valid_permission = ["me", "team"] kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
valid_chunk_method = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "email", "tag"] if kb is None:
check_validation = valid( return get_error_data_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
permission, except OperationalError as e:
valid_permission, logging.exception(e)
chunk_method, return get_error_data_result(message="Database operation failed")
valid_chunk_method,
) if req.get("parser_config"):
if check_validation: req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"])
return check_validation
if "tenant_id" in req: if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id and req.get("parser_config") is None:
if req["tenant_id"] != tenant_id: req["parser_config"] = get_parser_config(chunk_method, None)
return get_error_data_result(message="Can't change `tenant_id`.")
e, kb = KnowledgebaseService.get_by_id(dataset_id) if "name" in req and req["name"].lower() != kb.name.lower():
if "parser_config" in req: try:
temp_dict = kb.parser_config exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
temp_dict.update(req["parser_config"]) if exists:
req["parser_config"] = temp_dict return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
if "chunk_count" in req: except OperationalError as e:
if req["chunk_count"] != kb.chunk_num: logging.exception(e)
return get_error_data_result(message="Can't change `chunk_count`.") return get_error_data_result(message="Database operation failed")
req.pop("chunk_count")
if "document_count" in req: if "embd_id" in req:
if req["document_count"] != kb.doc_num: if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
return get_error_data_result(message="Can't change `document_count`.") return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
req.pop("document_count") ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
if req.get("chunk_method"): if not ok:
if kb.chunk_num != 0 and req["chunk_method"] != kb.parser_id: return err
return get_error_data_result(message="If `chunk_count` is not 0, `chunk_method` is not changeable.")
req["parser_id"] = req.pop("chunk_method") try:
if req["parser_id"] != kb.parser_id: if not KnowledgebaseService.update_by_id(kb.id, req):
if not req.get("parser_config"): return get_error_data_result(message="Update dataset error.(Database error)")
req["parser_config"] = get_parser_config(chunk_method, parser_config) except OperationalError as e:
if "embedding_model" in req: logging.exception(e)
if kb.chunk_num != 0 and req["embedding_model"] != kb.embd_id: return get_error_data_result(message="Database operation failed")
return get_error_data_result(message="If `chunk_count` is not 0, `embedding_model` is not changeable.")
if not req.get("embedding_model"):
return get_error_data_result("`embedding_model` can't be empty")
valid_embedding_models = [
"BAAI/bge-large-zh-v1.5",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-large-en-v1.5",
"BAAI/bge-small-en-v1.5",
"BAAI/bge-small-zh-v1.5",
"jinaai/jina-embeddings-v2-base-en",
"jinaai/jina-embeddings-v2-small-en",
"nomic-ai/nomic-embed-text-v1.5",
"sentence-transformers/all-MiniLM-L6-v2",
"text-embedding-v2",
"text-embedding-v3",
"maidalun1020/bce-embedding-base_v1",
]
embd_model = LLMService.query(llm_name=req["embedding_model"], model_type="embedding")
if embd_model:
if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(
tenant_id=tenant_id,
model_type="embedding",
llm_name=req.get("embedding_model"),
):
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
if not embd_model:
embd_model = TenantLLMService.query(tenant_id=tenant_id, model_type="embedding", llm_name=req.get("embedding_model"))
if not embd_model:
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
req["embd_id"] = req.pop("embedding_model")
if "name" in req:
req["name"] = req["name"].strip()
if len(req["name"]) >= 128:
return get_error_data_result(message="Dataset name should not be longer than 128 characters.")
if req["name"].lower() != kb.name.lower() and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
return get_error_data_result(message="Duplicated dataset name in updating dataset.")
flds = list(req.keys())
for f in flds:
if req[f] == "" and f in ["permission", "parser_id", "chunk_method"]:
del req[f]
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_error_data_result(message="Update dataset error.(Database error)")
return get_result(code=settings.RetCode.SUCCESS) return get_result(code=settings.RetCode.SUCCESS)
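
The field-name mapping noted in the comment above is implemented with Pydantic serialization aliases. A hedged sketch of the mechanism with a toy model (not the actual request schema):

```python
from pydantic import BaseModel, Field

class AliasDemo(BaseModel):
    # API-facing names are dumped under the storage column names
    embedding_model: str = Field(default="", serialization_alias="embd_id")
    chunk_method: str = Field(default="naive", serialization_alias="parser_id")

demo = AliasDemo(embedding_model="BAAI/bge-large-zh-v1.5@BAAI", chunk_method="qa")
print(demo.model_dump(by_alias=True))
# {'embd_id': 'BAAI/bge-large-zh-v1.5@BAAI', 'parser_id': 'qa'}
```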

View File

@ -19,6 +19,7 @@ import logging
import random import random
import time import time
from base64 import b64encode from base64 import b64encode
from copy import deepcopy
from functools import wraps from functools import wraps
from hmac import HMAC from hmac import HMAC
from io import BytesIO from io import BytesIO
@ -333,22 +334,6 @@ def generate_confirmation_token(tenant_id):
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34] return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]
def valid(permission, valid_permission, chunk_method, valid_chunk_method):
if valid_parameter(permission, valid_permission):
return valid_parameter(permission, valid_permission)
if valid_parameter(chunk_method, valid_chunk_method):
return valid_parameter(chunk_method, valid_chunk_method)
def valid_parameter(parameter, valid_values):
if parameter and parameter not in valid_values:
return get_error_data_result(f"'{parameter}' is not in {valid_values}")
def dataset_readonly_fields(field_name):
return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time", "created_by", "document_count", "token_num", "status", "tenant_id", "id"]
def get_parser_config(chunk_method, parser_config): def get_parser_config(chunk_method, parser_config):
if parser_config: if parser_config:
return parser_config return parser_config
@ -402,43 +387,6 @@ def get_data_openai(
} }
def valid_parser_config(parser_config):
if not parser_config:
return
scopes = set(
[
"chunk_token_num",
"delimiter",
"raptor",
"graphrag",
"layout_recognize",
"task_page_size",
"pages",
"html4excel",
"auto_keywords",
"auto_questions",
"tag_kb_ids",
"topn_tags",
"filename_embd_weight",
]
)
for k in parser_config.keys():
assert k in scopes, f"Abnormal 'parser_config'. Invalid key: {k}"
assert isinstance(parser_config.get("chunk_token_num", 1), int), "chunk_token_num should be int"
assert 1 <= parser_config.get("chunk_token_num", 1) < 100000000, "chunk_token_num should be in range from 1 to 100000000"
assert isinstance(parser_config.get("task_page_size", 1), int), "task_page_size should be int"
assert 1 <= parser_config.get("task_page_size", 1) < 100000000, "task_page_size should be in range from 1 to 100000000"
assert isinstance(parser_config.get("auto_keywords", 1), int), "auto_keywords should be int"
assert 0 <= parser_config.get("auto_keywords", 0) < 32, "auto_keywords should be in range from 0 to 32"
assert isinstance(parser_config.get("auto_questions", 1), int), "auto_questions should be int"
assert 0 <= parser_config.get("auto_questions", 0) < 10, "auto_questions should be in range from 0 to 10"
assert isinstance(parser_config.get("topn_tags", 1), int), "topn_tags should be int"
assert 0 <= parser_config.get("topn_tags", 0) < 10, "topn_tags should be in range from 0 to 10"
assert isinstance(parser_config.get("html4excel", False), bool), "html4excel should be True or False"
assert isinstance(parser_config.get("delimiter", ""), str), "delimiter should be str"
def check_duplicate_ids(ids, id_type="item"): def check_duplicate_ids(ids, id_type="item"):
""" """
Check for duplicate IDs in a list and return unique IDs and error messages. Check for duplicate IDs in a list and return unique IDs and error messages.
@ -469,7 +417,8 @@ def check_duplicate_ids(ids, id_type="item"):
def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]: def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]:
"""Verifies availability of an embedding model for a specific tenant. """
Verifies availability of an embedding model for a specific tenant.
Implements a four-stage validation process: Implements a four-stage validation process:
1. Model identifier parsing and validation 1. Model identifier parsing and validation
@ -518,3 +467,50 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, R
return False, get_error_data_result(message="Database operation failed") return False, get_error_data_result(message="Database operation failed")
return True, None return True, None
def deep_merge(default: dict, custom: dict) -> dict:
"""
Recursively merges two dictionaries with priority given to `custom` values.
Creates a deep copy of the `default` dictionary and iteratively merges nested
dictionaries using a stack-based approach. Non-dict values in `custom` will
completely override corresponding entries in `default`.
Args:
default (dict): Base dictionary containing default values.
custom (dict): Dictionary containing overriding values.
Returns:
dict: New merged dictionary combining values from both inputs.
Example:
>>> from copy import deepcopy
>>> default = {"a": 1, "nested": {"x": 10, "y": 20}}
>>> custom = {"b": 2, "nested": {"y": 99, "z": 30}}
>>> deep_merge(default, custom)
{'a': 1, 'b': 2, 'nested': {'x': 10, 'y': 99, 'z': 30}}
>>> deep_merge({"config": {"mode": "auto"}}, {"config": "manual"})
{'config': 'manual'}
Notes:
1. Merge priority is always given to `custom` values at all nesting levels
2. Non-dict values (e.g. list, str) in `custom` will replace entire values
in `default`, even if the original value was a dictionary
3. Time complexity: O(N) where N is total key-value pairs in `custom`
4. Recommended for configuration merging and nested data updates
"""
merged = deepcopy(default)
stack = [(merged, custom)]
while stack:
base_dict, override_dict = stack.pop()
for key, val in override_dict.items():
if key in base_dict and isinstance(val, dict) and isinstance(base_dict[key], dict):
stack.append((base_dict[key], val))
else:
base_dict[key] = val
return merged
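
As the update handler uses it, an existing `parser_config` can be merged with a partial override. A small example with made-up values:

```python
current = {"chunk_token_num": 128, "raptor": {"use_raptor": False}}
patch = {"raptor": {"use_raptor": True}, "html4excel": True}

assert deep_merge(current, patch) == {
    "chunk_token_num": 128,          # untouched default survives
    "raptor": {"use_raptor": True},  # nested key overridden
    "html4excel": True,              # new key added
}
```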

View File

@ -13,21 +13,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import uuid
from enum import auto from enum import auto
from typing import Annotated, Any from typing import Annotated, Any
from flask import Request from flask import Request
from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator from pydantic import UUID1, BaseModel, Field, StringConstraints, ValidationError, field_serializer, field_validator
from strenum import StrEnum from strenum import StrEnum
from werkzeug.exceptions import BadRequest, UnsupportedMediaType from werkzeug.exceptions import BadRequest, UnsupportedMediaType
from api.constants import DATASET_NAME_LIMIT from api.constants import DATASET_NAME_LIMIT
def validate_and_parse_json_request(request: Request, validator: type[BaseModel]) -> tuple[dict[str, Any] | None, str | None]: def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
"""Validates and parses JSON requests through a multi-stage validation pipeline. """
Validates and parses JSON requests through a multi-stage validation pipeline.
Implements a robust four-stage validation process: Implements a four-stage validation process:
1. Content-Type verification (must be application/json) 1. Content-Type verification (must be application/json)
2. JSON syntax validation 2. JSON syntax validation
3. Payload structure type checking 3. Payload structure type checking
@ -35,6 +37,10 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
Args: Args:
request (Request): Flask request object containing HTTP payload request (Request): Flask request object containing HTTP payload
validator (type[BaseModel]): Pydantic model class for data validation
extras (dict[str, Any] | None): Additional fields to merge into payload
before validation. These fields will be removed from the final output
exclude_unset (bool): Whether to exclude fields that have not been explicitly set
Returns: Returns:
tuple[Dict[str, Any] | None, str | None]: tuple[Dict[str, Any] | None, str | None]:
@ -46,26 +52,26 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
- Diagnostic error message on failure - Diagnostic error message on failure
Raises: Raises:
UnsupportedMediaType: When Content-Type application/json UnsupportedMediaType: When Content-Type header is not application/json
BadRequest: For structural JSON syntax errors BadRequest: For structural JSON syntax errors
ValidationError: When payload violates Pydantic schema rules ValidationError: When payload violates Pydantic schema rules
Examples: Examples:
Successful validation: >>> validate_and_parse_json_request(valid_request, DatasetSchema)
```python ({"name": "Dataset1", "format": "csv"}, None)
# Input: {"name": "Dataset1", "format": "csv"}
# Returns: ({"name": "Dataset1", "format": "csv"}, None)
```
Invalid Content-Type: >>> validate_and_parse_json_request(xml_request, DatasetSchema)
```python (None, "Unsupported content type: Expected application/json, got text/xml")
# Returns: (None, "Unsupported content type: Expected application/json, got text/xml")
```
Malformed JSON: >>> validate_and_parse_json_request(bad_json_request, DatasetSchema)
```python (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding")
# Returns: (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding")
``` Notes:
1. Validation Priority:
- Content-Type verification precedes JSON parsing
- Structural validation occurs before schema validation
2. Extra fields added via `extras` parameter are automatically removed
from the final output after validation
""" """
try: try:
payload = request.get_json() or {} payload = request.get_json() or {}
@ -78,17 +84,25 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
return None, f"Invalid request payload: expected object, got {type(payload).__name__}" return None, f"Invalid request payload: expected object, got {type(payload).__name__}"
try: try:
if extras is not None:
payload.update(extras)
validated_request = validator(**payload) validated_request = validator(**payload)
except ValidationError as e: except ValidationError as e:
return None, format_validation_error_message(e) return None, format_validation_error_message(e)
parsed_payload = validated_request.model_dump(by_alias=True) parsed_payload = validated_request.model_dump(by_alias=True, exclude_unset=exclude_unset)
if extras is not None:
for key in list(parsed_payload.keys()):
if key in extras:
del parsed_payload[key]
return parsed_payload, None return parsed_payload, None
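
A hedged sketch of the two new parameters in use, mirroring the update endpoint shown earlier (assumes a Flask request context, the UpdateDatasetReq model, and a request body of `{"name": "renamed"}`):

```python
# Inside the PUT /datasets/<dataset_id> handler (illustrative fragment):
req, err = validate_and_parse_json_request(
    request,
    UpdateDatasetReq,
    extras={"dataset_id": dataset_id},  # merged in for validation only
    exclude_unset=True,                 # keep only fields the client actually sent
)
# On success: req == {"name": "renamed"} -- dataset_id was validated, then removed.
```
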
def format_validation_error_message(e: ValidationError) -> str: def format_validation_error_message(e: ValidationError) -> str:
"""Formats validation errors into a standardized string format. """
Formats validation errors into a standardized string format.
Processes pydantic ValidationError objects to create human-readable error messages Processes pydantic ValidationError objects to create human-readable error messages
containing field locations, error descriptions, and input values. containing field locations, error descriptions, and input values.
@ -155,7 +169,6 @@ class GraphragMethodEnum(StrEnum):
class Base(BaseModel): class Base(BaseModel):
class Config: class Config:
extra = "forbid" extra = "forbid"
json_schema_extra = {"charset": "utf8mb4", "collation": "utf8mb4_0900_ai_ci"}
class RaptorConfig(Base): class RaptorConfig(Base):
@ -201,16 +214,17 @@ class CreateDatasetReq(Base):
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)] name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
avatar: str | None = Field(default=None, max_length=65535) avatar: str | None = Field(default=None, max_length=65535)
description: str | None = Field(default=None, max_length=65535) description: str | None = Field(default=None, max_length=65535)
embedding_model: Annotated[str | None, StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")] embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)] permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)]
chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")] chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")]
pagerank: int = Field(default=0, ge=0, le=100) pagerank: int = Field(default=0, ge=0, le=100)
parser_config: ParserConfig | None = Field(default=None) parser_config: ParserConfig = Field(default_factory=dict)
@field_validator("avatar") @field_validator("avatar")
@classmethod @classmethod
def validate_avatar_base64(cls, v: str) -> str: def validate_avatar_base64(cls, v: str | None) -> str | None:
"""Validates Base64-encoded avatar string format and MIME type compliance. """
Validates Base64-encoded avatar string format and MIME type compliance.
Implements a three-stage validation workflow: Implements a three-stage validation workflow:
1. MIME prefix existence check 1. MIME prefix existence check
@ -259,7 +273,8 @@ class CreateDatasetReq(Base):
@field_validator("embedding_model", mode="after") @field_validator("embedding_model", mode="after")
@classmethod @classmethod
def validate_embedding_model(cls, v: str) -> str: def validate_embedding_model(cls, v: str) -> str:
"""Validates embedding model identifier format compliance. """
Validates embedding model identifier format compliance.
Validation pipeline: Validation pipeline:
1. Structural format verification 1. Structural format verification
@ -298,11 +313,12 @@ class CreateDatasetReq(Base):
@field_validator("permission", mode="before") @field_validator("permission", mode="before")
@classmethod @classmethod
def permission_auto_lowercase(cls, v: str) -> str: def permission_auto_lowercase(cls, v: Any) -> Any:
"""Normalize permission input to lowercase for consistent PermissionEnum matching. """
Normalize permission input to lowercase for consistent PermissionEnum matching.
Args: Args:
v (str): Raw input value for the permission field v (Any): Raw input value for the permission field
Returns: Returns:
Lowercase string if input is string type, otherwise returns original value Lowercase string if input is string type, otherwise returns original value
@ -316,13 +332,13 @@ class CreateDatasetReq(Base):
@field_validator("parser_config", mode="after") @field_validator("parser_config", mode="after")
@classmethod @classmethod
def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserConfig | None: def validate_parser_config_json_length(cls, v: ParserConfig) -> ParserConfig:
"""Validates serialized JSON length constraints for parser configuration. """
Validates serialized JSON length constraints for parser configuration.
Implements a three-stage validation workflow: Implements a two-stage validation workflow:
1. Null check - bypass validation for empty configurations 1. Model serialization - convert Pydantic model to JSON string
2. Model serialization - convert Pydantic model to JSON string 2. Size verification - enforce maximum allowed payload size
3. Size verification - enforce maximum allowed payload size
Args: Args:
v (ParserConfig | None): Raw parser configuration object v (ParserConfig | None): Raw parser configuration object
@ -333,9 +349,15 @@ class CreateDatasetReq(Base):
Raises: Raises:
ValueError: When serialized JSON exceeds 65,535 characters ValueError: When serialized JSON exceeds 65,535 characters
""" """
if v is None:
return v
if (json_str := v.model_dump_json()) and len(json_str) > 65535: if (json_str := v.model_dump_json()) and len(json_str) > 65535:
raise ValueError(f"Parser config exceeds size limit (max 65,535 characters). Current size: {len(json_str):,}") raise ValueError(f"Parser config exceeds size limit (max 65,535 characters). Current size: {len(json_str):,}")
return v return v
class UpdateDatasetReq(CreateDatasetReq):
dataset_id: UUID1 = Field(...)
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
@field_serializer("dataset_id")
def serialize_uuid_to_hex(self, v: uuid.UUID) -> str:
return v.hex
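
A brief illustration of the new model, assuming the definitions above; the UUID below is a made-up version-1 value:

```python
req = UpdateDatasetReq(dataset_id="2b0c0e8e-2cf2-11f0-9a4e-8c1645a0e6c8", name="renamed")
dumped = req.model_dump(by_alias=True, exclude_unset=True)
# dumped["dataset_id"] == "2b0c0e8e2cf211f09a4e8c1645a0e6c8"  (hex via the field serializer)
# dumped["name"] == "renamed"; fields that were never set are omitted entirely
```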

View File

@ -385,7 +385,7 @@ curl --request POST \
- `"team"`: All team members can manage the dataset. - `"team"`: All team members can manage the dataset.
- `"pagerank"`: (*Body parameter*), `int` - `"pagerank"`: (*Body parameter*), `int`
Set page rank: refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank) refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank)
- Default: `0` - Default: `0`
- Minimum: `0` - Minimum: `0`
- Maximum: `100` - Maximum: `100`
@ -562,8 +562,13 @@ Updates configurations for a specified dataset.
- `'Authorization: Bearer <YOUR_API_KEY>'` - `'Authorization: Bearer <YOUR_API_KEY>'`
- Body: - Body:
- `"name"`: `string` - `"name"`: `string`
- `"avatar"`: `string`
- `"description"`: `string`
- `"embedding_model"`: `string` - `"embedding_model"`: `string`
- `"chunk_method"`: `enum<string>` - `"permission"`: `string`
- `"chunk_method"`: `string`
- `"pagerank"`: `int`
- `"parser_config"`: `object`
##### Request example ##### Request example
@ -584,22 +589,74 @@ curl --request PUT \
The ID of the dataset to update. The ID of the dataset to update.
- `"name"`: (*Body parameter*), `string` - `"name"`: (*Body parameter*), `string`
The revised name of the dataset. The revised name of the dataset.
- Basic Multilingual Plane (BMP) only
- Maximum 128 characters
- Case-insensitive
- `"avatar"`: (*Body parameter*), `string`
The updated base64 encoding of the avatar.
- Maximum 65535 characters
- `"embedding_model"`: (*Body parameter*), `string` - `"embedding_model"`: (*Body parameter*), `string`
The updated embedding model name. The updated embedding model name.
- Ensure that `"chunk_count"` is `0` before updating `"embedding_model"`. - Ensure that `"chunk_count"` is `0` before updating `"embedding_model"`.
- Maximum 255 characters
- Must follow `model_name@model_factory` format
- `"permission"`: (*Body parameter*), `string`
The updated dataset permission. Available options:
- `"me"`: (Default) Only you can manage the dataset.
- `"team"`: All team members can manage the dataset.
- `"pagerank"`: (*Body parameter*), `int`
refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank)
- Default: `0`
- Minimum: `0`
- Maximum: `100`
- `"chunk_method"`: (*Body parameter*), `enum<string>` - `"chunk_method"`: (*Body parameter*), `enum<string>`
The chunking method for the dataset. Available options: The chunking method for the dataset. Available options:
- `"naive"`: General - `"naive"`: General (default)
- `"manual`: Manual - `"book"`: Book
- `"email"`: Email
- `"laws"`: Laws
- `"manual"`: Manual
- `"one"`: One
- `"paper"`: Paper
- `"picture"`: Picture
- `"presentation"`: Presentation
- `"qa"`: Q&A - `"qa"`: Q&A
- `"table"`: Table - `"table"`: Table
- `"paper"`: Paper - `"tag"`: Tag
- `"book"`: Book - `"parser_config"`: (*Body parameter*), `object`
- `"laws"`: Laws The configuration settings for the dataset parser. The attributes in this JSON object vary with the selected `"chunk_method"`:
- `"presentation"`: Presentation - If `"chunk_method"` is `"naive"`, the `"parser_config"` object contains the following attributes:
- `"picture"`: Picture - `"auto_keywords"`: `int`
- `"one"`:One - Defaults to `0`
- `"email"`: Email - Minimum: `0`
- Maximum: `32`
- `"auto_questions"`: `int`
- Defaults to `0`
- Minimum: `0`
- Maximum: `10`
- `"chunk_token_num"`: `int`
- Defaults to `128`
- Minimum: `1`
- Maximum: `2048`
- `"delimiter"`: `string`
- Defaults to `"\n"`.
- `"html4excel"`: `bool` Indicates whether to convert Excel documents into HTML format.
- Defaults to `false`
- `"layout_recognize"`: `string`
- Defaults to `DeepDOC`
- `"tag_kb_ids"`: `array<string>` refer to [Use tag set](https://ragflow.io/docs/dev/use_tag_sets)
- Must include a list of dataset IDs, where each dataset is parsed using the Tag Chunk Method
- `"task_page_size"`: `int` For PDF only.
- Defaults to `12`
- Minimum: `1`
- `"raptor"`: `object` RAPTOR-specific settings.
- Defaults to: `{"use_raptor": false}`
- `"graphrag"`: `object` GRAPHRAG-specific settings.
- Defaults to: `{"use_graphrag": false}`
- If `"chunk_method"` is `"qa"`, `"manuel"`, `"paper"`, `"book"`, `"laws"`, or `"presentation"`, the `"parser_config"` object contains the following attribute:
- `"raptor"`: `object` RAPTOR-specific settings.
- Defaults to: `{"use_raptor": false}`.
- If `"chunk_method"` is `"table"`, `"picture"`, `"one"`, or `"email"`, `"parser_config"` is an empty JSON object.
#### Response #### Response

View File

@ -306,20 +306,40 @@ Updates configurations for the current dataset.
A dictionary representing the attributes to update, with the following keys: A dictionary representing the attributes to update, with the following keys:
- `"name"`: `str` The revised name of the dataset. - `"name"`: `str` The revised name of the dataset.
- `"embedding_model"`: `str` The updated embedding model name. - Basic Multilingual Plane (BMP) only
- Maximum 128 characters
- Case-insensitive
- `"avatar"`: (*Body parameter*), `string`
The updated base64 encoding of the avatar.
- Maximum 65535 characters
- `"embedding_model"`: (*Body parameter*), `string`
The updated embedding model name.
- Ensure that `"chunk_count"` is `0` before updating `"embedding_model"`. - Ensure that `"chunk_count"` is `0` before updating `"embedding_model"`.
- `"chunk_method"`: `str` The chunking method for the dataset. Available options: - Maximum 255 characters
- `"naive"`: General - Must follow `model_name@model_factory` format
- `"manual`: Manual - `"permission"`: (*Body parameter*), `string`
The updated dataset permission. Available options:
- `"me"`: (Default) Only you can manage the dataset.
- `"team"`: All team members can manage the dataset.
- `"pagerank"`: (*Body parameter*), `int`
refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank)
- Default: `0`
- Minimum: `0`
- Maximum: `100`
- `"chunk_method"`: (*Body parameter*), `enum<string>`
The chunking method for the dataset. Available options:
- `"naive"`: General (default)
- `"book"`: Book
- `"email"`: Email
- `"laws"`: Laws
- `"manual"`: Manual
- `"one"`: One
- `"paper"`: Paper
- `"picture"`: Picture
- `"presentation"`: Presentation
- `"qa"`: Q&A - `"qa"`: Q&A
- `"table"`: Table - `"table"`: Table
- `"paper"`: Paper - `"tag"`: Tag
- `"book"`: Book
- `"laws"`: Laws
- `"presentation"`: Presentation
- `"picture"`: Picture
- `"one"`: One
- `"email"`: Email
#### Returns #### Returns

View File

@ -59,21 +59,19 @@ class RAGFlow:
pagerank: int = 0, pagerank: int = 0,
parser_config: DataSet.ParserConfig = None, parser_config: DataSet.ParserConfig = None,
) -> DataSet: ) -> DataSet:
if parser_config: payload = {
parser_config = parser_config.to_json() "name": name,
res = self.post( "avatar": avatar,
"/datasets", "description": description,
{ "embedding_model": embedding_model,
"name": name, "permission": permission,
"avatar": avatar, "chunk_method": chunk_method,
"description": description, "pagerank": pagerank,
"embedding_model": embedding_model, }
"permission": permission, if parser_config is not None:
"chunk_method": chunk_method, payload["parser_config"] = parser_config.to_json()
"pagerank": pagerank,
"parser_config": parser_config, res = self.post("/datasets", payload)
},
)
res = res.json() res = res.json()
if res.get("code") == 0: if res.get("code") == 0:
return DataSet(self, res["data"]) return DataSet(self, res["data"])

View File

@ -0,0 +1,28 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hypothesis.strategies as st
@st.composite
def valid_names(draw):
base_chars = "abcdefghijklmnopqrstuvwxyz_"
first_char = draw(st.sampled_from([c for c in base_chars if c.isalpha() or c == "_"]))
remaining = draw(st.text(alphabet=st.sampled_from(base_chars), min_size=0, max_size=128 - 2))
name = (first_char + remaining)[:128]
return name.encode("utf-8").decode("utf-8")

View File

@ -39,23 +39,23 @@ SESSION_WITH_CHAT_NAME_LIMIT = 255
# DATASET MANAGEMENT # DATASET MANAGEMENT
def create_dataset(auth, payload=None, headers=HEADERS, data=None): def create_dataset(auth, payload=None, *, headers=HEADERS, data=None):
res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data) res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
return res.json() return res.json()
def list_datasets(auth, params=None, headers=HEADERS): def list_datasets(auth, params=None, *, headers=HEADERS):
res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params) res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params)
return res.json() return res.json()
def update_dataset(auth, dataset_id, payload=None, headers=HEADERS): def update_dataset(auth, dataset_id, payload=None, *, headers=HEADERS, data=None):
res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload) res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload, data=data)
return res.json() return res.json()
def delete_datasets(auth, payload=None, headers=HEADERS): def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None):
res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload) res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
return res.json() return res.json()

View File

@ -37,3 +37,13 @@ def add_datasets_func(get_http_api_auth, request):
request.addfinalizer(cleanup) request.addfinalizer(cleanup)
return batch_create_datasets(get_http_api_auth, 3) return batch_create_datasets(get_http_api_auth, 3)
@pytest.fixture(scope="function")
def add_dataset_func(get_http_api_auth, request):
def cleanup():
delete_datasets(get_http_api_auth)
request.addfinalizer(cleanup)
return batch_create_datasets(get_http_api_auth, 1)[0]
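
A hedged example of a test consuming this fixture, assuming it yields a single dataset ID as the batch helper suggests (fragment; relies on the test module's existing imports):

```python
@pytest.mark.p1
def test_update_name(get_http_api_auth, add_dataset_func):
    dataset_id = add_dataset_func  # freshly created dataset, deleted on teardown
    res = update_dataset(get_http_api_auth, dataset_id, {"name": "new_name"})
    assert res["code"] == 0, res
```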

View File

@ -13,30 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from concurrent.futures import ThreadPoolExecutor
import hypothesis.strategies as st
import pytest import pytest
from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, create_dataset from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, create_dataset
from hypothesis import example, given, settings from hypothesis import example, given, settings
from libs.auth import RAGFlowHttpApiAuth from libs.auth import RAGFlowHttpApiAuth
from libs.utils import encode_avatar from libs.utils import encode_avatar
from libs.utils.file_utils import create_image_file from libs.utils.file_utils import create_image_file
from libs.utils.hypothesis_utils import valid_names
@st.composite
def valid_names(draw):
base_chars = "abcdefghijklmnopqrstuvwxyz_"
first_char = draw(st.sampled_from([c for c in base_chars if c.isalpha() or c == "_"]))
remaining = draw(st.text(alphabet=st.sampled_from(base_chars), min_size=0, max_size=DATASET_NAME_LIMIT - 2))
name = (first_char + remaining)[:128]
return name.encode("utf-8").decode("utf-8")
@pytest.mark.p1
@pytest.mark.usefixtures("clear_datasets") @pytest.mark.usefixtures("clear_datasets")
class TestAuthorization: class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"auth, expected_code, expected_message", "auth, expected_code, expected_message",
[ [
@ -49,64 +39,17 @@ class TestAuthorization:
], ],
ids=["empty_auth", "invalid_api_token"], ids=["empty_auth", "invalid_api_token"],
) )
def test_invalid_auth(self, auth, expected_code, expected_message): def test_auth_invalid(self, auth, expected_code, expected_message):
res = create_dataset(auth, {"name": "auth_test"}) res = create_dataset(auth, {"name": "auth_test"})
assert res["code"] == expected_code assert res["code"] == expected_code, res
assert res["message"] == expected_message assert res["message"] == expected_message, res
@pytest.mark.usefixtures("clear_datasets") class TestRquest:
class TestDatasetCreation:
@pytest.mark.p1
@given(name=valid_names())
@example("a" * 128)
@settings(max_examples=20)
def test_valid_name(self, get_http_api_auth, name):
res = create_dataset(get_http_api_auth, {"name": name})
assert res["code"] == 0, res
assert res["data"]["name"] == name, res
@pytest.mark.p1
@pytest.mark.parametrize(
"name, expected_message",
[
("", "String should have at least 1 character"),
(" ", "String should have at least 1 character"),
("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"),
(0, "Input should be a valid string"),
],
ids=["empty_name", "space_name", "too_long_name", "invalid_name"],
)
def test_invalid_name(self, get_http_api_auth, name, expected_message):
res = create_dataset(get_http_api_auth, {"name": name})
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p2
def test_duplicated_name(self, get_http_api_auth):
name = "duplicated_name"
payload = {"name": name}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert res["message"] == f"Dataset name '{name}' already exists", res
@pytest.mark.p2
def test_case_insensitive(self, get_http_api_auth):
name = "CaseInsensitive"
res = create_dataset(get_http_api_auth, {"name": name.upper()})
assert res["code"] == 0, res
res = create_dataset(get_http_api_auth, {"name": name.lower()})
assert res["code"] == 101, res
assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
@pytest.mark.p3 @pytest.mark.p3
def test_bad_content_type(self, get_http_api_auth): def test_content_type_bad(self, get_http_api_auth):
BAD_CONTENT_TYPE = "text/xml" BAD_CONTENT_TYPE = "text/xml"
res = create_dataset(get_http_api_auth, {"name": "name"}, {"Content-Type": BAD_CONTENT_TYPE}) res = create_dataset(get_http_api_auth, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE})
assert res["code"] == 101, res assert res["code"] == 101, res
assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res
@ -115,15 +58,85 @@ class TestDatasetCreation:
"payload, expected_message", "payload, expected_message",
[ [
("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"), ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"),
('"a"', "Invalid request payload: expected objec"), ('"a"', "Invalid request payload: expected object, got str"),
], ],
ids=["malformed_json_syntax", "invalid_request_payload_type"], ids=["malformed_json_syntax", "invalid_request_payload_type"],
) )
def test_bad_payload(self, get_http_api_auth, payload, expected_message): def test_payload_bad(self, get_http_api_auth, payload, expected_message):
res = create_dataset(get_http_api_auth, data=payload) res = create_dataset(get_http_api_auth, data=payload)
assert res["code"] == 101, res assert res["code"] == 101, res
assert res["message"] == expected_message, res
@pytest.mark.usefixtures("clear_datasets")
class TestCapability:
@pytest.mark.p3
def test_create_dataset_1k(self, get_http_api_auth):
for i in range(1_000):
payload = {"name": f"dataset_{i}"}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, f"Failed to create dataset {i}"
@pytest.mark.p3
def test_create_dataset_concurrent(self, get_http_api_auth):
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(create_dataset, get_http_api_auth, {"name": f"dataset_{i}"}) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses), responses
@pytest.mark.usefixtures("clear_datasets")
class TestDatasetCreate:
@pytest.mark.p1
@given(name=valid_names())
@example("a" * 128)
@settings(max_examples=20)
def test_name(self, get_http_api_auth, name):
res = create_dataset(get_http_api_auth, {"name": name})
assert res["code"] == 0, res
assert res["data"]["name"] == name, res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, expected_message",
[
("", "String should have at least 1 character"),
(" ", "String should have at least 1 character"),
("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"),
(0, "Input should be a valid string"),
(None, "Input should be a valid string"),
],
ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"],
)
def test_name_invalid(self, get_http_api_auth, name, expected_message):
payload = {"name": name}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res assert expected_message in res["message"], res
@pytest.mark.p3
def test_name_duplicated(self, get_http_api_auth):
name = "duplicated_name"
payload = {"name": name}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 102, res
assert res["message"] == f"Dataset name '{name}' already exists", res
@pytest.mark.p3
def test_name_case_insensitive(self, get_http_api_auth):
name = "CaseInsensitive"
payload = {"name": name.upper()}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
payload = {"name": name.lower()}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 102, res
assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
@pytest.mark.p2 @pytest.mark.p2
def test_avatar(self, get_http_api_auth, tmp_path): def test_avatar(self, get_http_api_auth, tmp_path):
fn = create_image_file(tmp_path / "ragflow_test.png") fn = create_image_file(tmp_path / "ragflow_test.png")
@ -134,16 +147,10 @@ class TestDatasetCreation:
res = create_dataset(get_http_api_auth, payload) res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res assert res["code"] == 0, res
@pytest.mark.p3
def test_avatar_none(self, get_http_api_auth, tmp_path):
payload = {"name": "test_avatar_none", "avatar": None}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["avatar"] is None, res
@pytest.mark.p2 @pytest.mark.p2
def test_avatar_exceeds_limit_length(self, get_http_api_auth): def test_avatar_exceeds_limit_length(self, get_http_api_auth):
res = create_dataset(get_http_api_auth, {"name": "exceeds_limit_length_avatar", "avatar": "a" * 65536}) payload = {"name": "exceeds_limit_length_avatar", "avatar": "a" * 65536}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res assert res["code"] == 101, res
assert "String should have at most 65535 characters" in res["message"], res assert "String should have at most 65535 characters" in res["message"], res
@ -158,7 +165,7 @@ class TestDatasetCreation:
], ],
ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"], ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"],
) )
def test_invalid_avatar_prefix(self, get_http_api_auth, tmp_path, name, avatar_prefix, expected_message): def test_avatar_invalid_prefix(self, get_http_api_auth, tmp_path, name, avatar_prefix, expected_message):
fn = create_image_file(tmp_path / "ragflow_test.png") fn = create_image_file(tmp_path / "ragflow_test.png")
payload = { payload = {
"name": name, "name": name,
@ -169,11 +176,25 @@ class TestDatasetCreation:
assert expected_message in res["message"], res assert expected_message in res["message"], res
@pytest.mark.p3 @pytest.mark.p3
def test_description_none(self, get_http_api_auth): def test_avatar_unset(self, get_http_api_auth):
payload = {"name": "test_description_none", "description": None} payload = {"name": "test_avatar_unset"}
res = create_dataset(get_http_api_auth, payload) res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res assert res["code"] == 0, res
assert res["data"]["description"] is None, res assert res["data"]["avatar"] is None, res
@pytest.mark.p3
def test_avatar_none(self, get_http_api_auth):
payload = {"name": "test_avatar_none", "avatar": None}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["avatar"] is None, res
@pytest.mark.p2
def test_description(self, get_http_api_auth):
payload = {"name": "test_description", "description": "description"}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["description"] == "description", res
@pytest.mark.p2 @pytest.mark.p2
def test_description_exceeds_limit_length(self, get_http_api_auth): def test_description_exceeds_limit_length(self, get_http_api_auth):
@ -182,6 +203,20 @@ class TestDatasetCreation:
assert res["code"] == 101, res assert res["code"] == 101, res
assert "String should have at most 65535 characters" in res["message"], res assert "String should have at most 65535 characters" in res["message"], res
@pytest.mark.p3
def test_description_unset(self, get_http_api_auth):
payload = {"name": "test_description_unset"}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["description"] is None, res
@pytest.mark.p3
def test_description_none(self, get_http_api_auth):
payload = {"name": "test_description_none", "description": None}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["description"] is None, res
@pytest.mark.p1 @pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name, embedding_model", "name, embedding_model",
@ -189,22 +224,14 @@ class TestDatasetCreation:
("BAAI/bge-large-zh-v1.5@BAAI", "BAAI/bge-large-zh-v1.5@BAAI"), ("BAAI/bge-large-zh-v1.5@BAAI", "BAAI/bge-large-zh-v1.5@BAAI"),
("maidalun1020/bce-embedding-base_v1@Youdao", "maidalun1020/bce-embedding-base_v1@Youdao"), ("maidalun1020/bce-embedding-base_v1@Youdao", "maidalun1020/bce-embedding-base_v1@Youdao"),
("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"), ("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"),
("embedding_model_default", None),
], ],
ids=["builtin_baai", "builtin_youdao", "tenant_zhipu", "default"], ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"],
) )
def test_valid_embedding_model(self, get_http_api_auth, name, embedding_model): def test_embedding_model(self, get_http_api_auth, name, embedding_model):
if embedding_model is None: payload = {"name": name, "embedding_model": embedding_model}
payload = {"name": name}
else:
payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(get_http_api_auth, payload) res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res assert res["code"] == 0, res
if embedding_model is None: assert res["data"]["embedding_model"] == embedding_model, res
assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res
else:
assert res["data"]["embedding_model"] == embedding_model, res
@pytest.mark.p2 @pytest.mark.p2
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -217,7 +244,7 @@ class TestDatasetCreation:
], ],
ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"], ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"],
) )
def test_invalid_embedding_model(self, get_http_api_auth, name, embedding_model): def test_embedding_model_invalid(self, get_http_api_auth, name, embedding_model):
payload = {"name": name, "embedding_model": embedding_model} payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(get_http_api_auth, payload) res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res assert res["code"] == 101, res
@ -247,6 +274,20 @@ class TestDatasetCreation:
else: else:
assert "Both model_name and provider must be non-empty strings" in res["message"], res assert "Both model_name and provider must be non-empty strings" in res["message"], res
@pytest.mark.p2
def test_embedding_model_unset(self, get_http_api_auth):
payload = {"name": "embedding_model_unset"}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res
@pytest.mark.p2
def test_embedding_model_none(self, get_http_api_auth):
payload = {"name": "test_embedding_model_none", "embedding_model": None}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert "Input should be a valid string" in res["message"], res
@pytest.mark.p1 @pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name, permission", "name, permission",
@ -255,21 +296,14 @@ class TestDatasetCreation:
("team", "team"), ("team", "team"),
("me_upercase", "ME"), ("me_upercase", "ME"),
("team_upercase", "TEAM"), ("team_upercase", "TEAM"),
("permission_default", None),
], ],
ids=["me", "team", "me_upercase", "team_upercase", "permission_default"], ids=["me", "team", "me_upercase", "team_upercase"],
) )
def test_valid_permission(self, get_http_api_auth, name, permission): def test_permission(self, get_http_api_auth, name, permission):
if permission is None: payload = {"name": name, "permission": permission}
payload = {"name": name}
else:
payload = {"name": name, "permission": permission}
res = create_dataset(get_http_api_auth, payload) res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res assert res["code"] == 0, res
if permission is None: assert res["data"]["permission"] == permission.lower(), res
assert res["data"]["permission"] == "me", res
else:
assert res["data"]["permission"] == permission.lower(), res
@pytest.mark.p2 @pytest.mark.p2
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -279,13 +313,28 @@ class TestDatasetCreation:
("unknown", "unknown"), ("unknown", "unknown"),
("type_error", list()), ("type_error", list()),
], ],
ids=["empty", "unknown", "type_error"],
) )
def test_invalid_permission(self, get_http_api_auth, name, permission): def test_permission_invalid(self, get_http_api_auth, name, permission):
payload = {"name": name, "permission": permission} payload = {"name": name, "permission": permission}
res = create_dataset(get_http_api_auth, payload) res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101 assert res["code"] == 101
assert "Input should be 'me' or 'team'" in res["message"] assert "Input should be 'me' or 'team'" in res["message"]
@pytest.mark.p2
def test_permission_unset(self, get_http_api_auth):
payload = {"name": "test_permission_unset"}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["permission"] == "me", res
@pytest.mark.p3
def test_permission_none(self, get_http_api_auth):
payload = {"name": "test_permission_none", "permission": None}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert "Input should be 'me' or 'team'" in res["message"], res
@pytest.mark.p1 @pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name, chunk_method", "name, chunk_method",
@ -302,20 +351,14 @@ class TestDatasetCreation:
("qa", "qa"), ("qa", "qa"),
("table", "table"), ("table", "table"),
("tag", "tag"), ("tag", "tag"),
("chunk_method_default", None),
], ],
ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
) )
def test_valid_chunk_method(self, get_http_api_auth, name, chunk_method): def test_chunk_method(self, get_http_api_auth, name, chunk_method):
if chunk_method is None: payload = {"name": name, "chunk_method": chunk_method}
payload = {"name": name}
else:
payload = {"name": name, "chunk_method": chunk_method}
res = create_dataset(get_http_api_auth, payload) res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res assert res["code"] == 0, res
if chunk_method is None: assert res["data"]["chunk_method"] == chunk_method, res
assert res["data"]["chunk_method"] == "naive", res
else:
assert res["data"]["chunk_method"] == chunk_method, res
@pytest.mark.p2
@pytest.mark.parametrize(
@ -325,19 +368,77 @@ class TestDatasetCreation:
("unknown", "unknown"),
("type_error", list()),
],
ids=["empty", "unknown", "type_error"],
)
def test_invalid_chunk_method(self, get_http_api_auth, name, chunk_method):
def test_chunk_method_invalid(self, get_http_api_auth, name, chunk_method):
payload = {"name": name, "chunk_method": chunk_method}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
@pytest.mark.p2
def test_chunk_method_unset(self, get_http_api_auth):
payload = {"name": "test_chunk_method_unset"}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["chunk_method"] == "naive", res
@pytest.mark.p3
def test_chunk_method_none(self, get_http_api_auth):
payload = {"name": "chunk_method_none", "chunk_method": None}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, pagerank",
[
("pagerank_min", 0),
("pagerank_mid", 50),
("pagerank_max", 100),
],
ids=["min", "mid", "max"],
)
def test_pagerank(self, get_http_api_auth, name, pagerank):
payload = {"name": name, "pagerank": pagerank}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["pagerank"] == pagerank, res
@pytest.mark.p3
@pytest.mark.parametrize(
"name, pagerank, expected_message",
[
("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
],
ids=["min_limit", "max_limit"],
)
def test_pagerank_invalid(self, get_http_api_auth, name, pagerank, expected_message):
payload = {"name": name, "pagerank": pagerank}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p3
def test_pagerank_unset(self, get_http_api_auth):
payload = {"name": "pagerank_unset"}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["pagerank"] == 0, res
@pytest.mark.p3
def test_pagerank_none(self, get_http_api_auth):
payload = {"name": "pagerank_unset", "pagerank": None}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert "Input should be a valid integer" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
"name, parser_config",
[
("default_none", None),
("default_empty", {}),
("auto_keywords_min", {"auto_keywords": 0}),
("auto_keywords_mid", {"auto_keywords": 16}),
("auto_keywords_max", {"auto_keywords": 32}),
@ -363,7 +464,7 @@ class TestDatasetCreation:
("task_page_size_min", {"task_page_size": 1}),
("task_page_size_None", {"task_page_size": None}),
("pages", {"pages": [[1, 100]]}),
("pages_none", None),
("pages_none", {"pages": None}),
("graphrag_true", {"graphrag": {"use_graphrag": True}}),
("graphrag_false", {"graphrag": {"use_graphrag": False}}),
("graphrag_entity_types", {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}),
@ -388,8 +489,6 @@ class TestDatasetCreation:
("raptor_random_seed_min", {"raptor": {"random_seed": 0}}),
],
ids=[
"default_none",
"default_empty",
"auto_keywords_min",
"auto_keywords_mid",
"auto_keywords_max",
@ -440,44 +539,16 @@ class TestDatasetCreation:
"raptor_random_seed_min",
],
)
def test_valid_parser_config(self, get_http_api_auth, name, parser_config):
if parser_config is None:
payload = {"name": name}
else:
payload = {"name": name, "parser_config": parser_config}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
if parser_config is None:
assert res["data"]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": r"\n",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}
elif parser_config == {}:
assert res["data"]["parser_config"] == {
"auto_keywords": 0,
"auto_questions": 0,
"chunk_token_num": 128,
"delimiter": r"\n",
"filename_embd_weight": None,
"graphrag": None,
"html4excel": False,
"layout_recognize": "DeepDOC",
"pages": None,
"raptor": None,
"tag_kb_ids": [],
"task_page_size": None,
"topn_tags": 1,
}
else:
for k, v in parser_config.items():
if isinstance(v, dict):
for kk, vv in v.items():
assert res["data"]["parser_config"][k][kk] == vv
else:
assert res["data"]["parser_config"][k] == v
def test_parser_config(self, get_http_api_auth, name, parser_config):
payload = {"name": name, "parser_config": parser_config}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
for k, v in parser_config.items():
if isinstance(v, dict):
for kk, vv in v.items():
assert res["data"]["parser_config"][k][kk] == vv, res
else:
assert res["data"]["parser_config"][k] == v, res
@pytest.mark.p2
@pytest.mark.parametrize(
@ -595,15 +666,72 @@ class TestDatasetCreation:
"parser_config_type_invalid",
],
)
def test_invalid_parser_config(self, get_http_api_auth, name, parser_config, expected_message):
def test_parser_config_invalid(self, get_http_api_auth, name, parser_config, expected_message):
payload = {"name": name, "parser_config": parser_config}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p2
def test_parser_config_empty(self, get_http_api_auth):
payload = {"name": "default_empty", "parser_config": {}}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["parser_config"] == {
"auto_keywords": 0,
"auto_questions": 0,
"chunk_token_num": 128,
"delimiter": r"\n",
"filename_embd_weight": None,
"graphrag": None,
"html4excel": False,
"layout_recognize": "DeepDOC",
"pages": None,
"raptor": None,
"tag_kb_ids": [],
"task_page_size": None,
"topn_tags": 1,
}
@pytest.mark.p2
def test_parser_config_unset(self, get_http_api_auth):
payload = {"name": "default_unset"}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": r"\n",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}, res
@pytest.mark.p3
def test_dataset_10k(self, get_http_api_auth):
for i in range(10_000):
payload = {"name": f"dataset_{i}"}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, f"Failed to create dataset {i}"
@pytest.mark.p3
def test_parser_config_none(self, get_http_api_auth):
payload = {"name": "default_none", "parser_config": None}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert "Input should be a valid dictionary or instance of ParserConfig" in res["message"], res
@pytest.mark.p2
@pytest.mark.parametrize(
"payload",
[
{"name": "id", "id": "id"},
{"name": "tenant_id", "tenant_id": "e57c1966f99211efb41e9e45646e0111"},
{"name": "created_by", "created_by": "created_by"},
{"name": "create_date", "create_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
{"name": "create_time", "create_time": 1741671443322},
{"name": "update_date", "update_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
{"name": "update_time", "update_time": 1741671443339},
{"name": "document_count", "document_count": 1},
{"name": "chunk_count", "chunk_count": 1},
{"name": "token_num", "token_num": 1},
{"name": "status", "status": "1"},
{"name": "unknown_field", "unknown_field": "unknown_field"},
],
)
def test_unsupported_field(self, get_http_api_auth, payload):
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert "Extra inputs are not permitted" in res["message"], res

View File

@ -16,21 +16,18 @@
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
DATASET_NAME_LIMIT,
INVALID_API_TOKEN,
list_datasets,
update_dataset,
)
from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset
from hypothesis import HealthCheck, example, given, settings
from libs.auth import RAGFlowHttpApiAuth
from libs.utils import encode_avatar
from libs.utils.file_utils import create_image_file
from libs.utils.hypothesis_utils import valid_names
# TODO: Missing scenario for updating embedding_model with chunk_count != 0
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
[
@ -41,111 +38,178 @@ class TestAuthorization:
"Authentication error: API key is invalid!",
),
],
ids=["empty_auth", "invalid_api_token"],
)
def test_invalid_auth(self, auth, expected_code, expected_message):
def test_auth_invalid(self, auth, expected_code, expected_message):
res = update_dataset(auth, "dataset_id")
assert res["code"] == expected_code
assert res["code"] == expected_code, res
assert res["message"] == expected_message
assert res["message"] == expected_message, res
class TestRequest:
@pytest.mark.p3
def test_bad_content_type(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
BAD_CONTENT_TYPE = "text/xml"
res = update_dataset(get_http_api_auth, dataset_id, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE})
assert res["code"] == 101, res
assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res
@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_message",
[
("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"),
('"a"', "Invalid request payload: expected object, got str"),
],
ids=["malformed_json_syntax", "invalid_request_payload_type"],
)
def test_payload_bad(self, get_http_api_auth, add_dataset_func, payload, expected_message):
dataset_id = add_dataset_func
res = update_dataset(get_http_api_auth, dataset_id, data=payload)
assert res["code"] == 101, res
assert res["message"] == expected_message, res
@pytest.mark.p2
def test_payload_empty(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
res = update_dataset(get_http_api_auth, dataset_id, {})
assert res["code"] == 101, res
assert res["message"] == "No properties were modified", res
class TestCapability:
@pytest.mark.p3
def test_update_dataset_concurrent(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses), responses
@pytest.mark.p1
class TestDatasetUpdate:
@pytest.mark.parametrize(
"name, expected_code, expected_message",
[
("valid_name", 0, ""),
(
"a" * (DATASET_NAME_LIMIT + 1),
102,
"Dataset name should not be longer than 128 characters.",
),
(0, 100, """AttributeError("\'int\' object has no attribute \'strip\'")"""),
(
None,
100,
"""AttributeError("\'NoneType\' object has no attribute \'strip\'")""",
),
pytest.param("", 102, "", marks=pytest.mark.skip(reason="issue/5915")),
("dataset_1", 102, "Duplicated dataset name in updating dataset."),
("DATASET_1", 102, "Duplicated dataset name in updating dataset."),
],
)
def test_name(self, get_http_api_auth, add_datasets_func, name, expected_code, expected_message):
dataset_ids = add_datasets_func
res = update_dataset(get_http_api_auth, dataset_ids[0], {"name": name})
assert res["code"] == expected_code
if expected_code == 0:
res = list_datasets(get_http_api_auth, {"id": dataset_ids[0]})
assert res["data"][0]["name"] == name
else:
assert res["message"] == expected_message
@pytest.mark.parametrize(
"embedding_model, expected_code, expected_message",
[
("BAAI/bge-large-zh-v1.5", 0, ""),
("maidalun1020/bce-embedding-base_v1", 0, ""),
(
"other_embedding_model",
102,
"`embedding_model` other_embedding_model doesn't exist",
),
(None, 102, "`embedding_model` can't be empty"),
],
)
def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model, expected_code, expected_message):
dataset_id = add_dataset_func
res = update_dataset(get_http_api_auth, dataset_id, {"embedding_model": embedding_model})
assert res["code"] == expected_code
if expected_code == 0:
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["data"][0]["embedding_model"] == embedding_model
else:
assert res["message"] == expected_message
@pytest.mark.p3
def test_dataset_id_not_uuid(self, get_http_api_auth):
payload = {"name": "dataset_id_not_uuid"}
res = update_dataset(get_http_api_auth, "not_uuid", payload)
assert res["code"] == 101, res
assert "Input should be a valid UUID" in res["message"], res
@pytest.mark.p3
def test_dataset_id_wrong_uuid(self, get_http_api_auth):
payload = {"name": "wrong_uuid"}
res = update_dataset(get_http_api_auth, "d94a8dc02c9711f0930f7fbc369eab6d", payload)
assert res["code"] == 102, res
assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p1
@given(name=valid_names())
@example("a" * 128)
@settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture])
def test_name(self, get_http_api_auth, add_dataset_func, name):
dataset_id = add_dataset_func
payload = {"name": name}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(get_http_api_auth)
assert res["code"] == 0, res
assert res["data"][0]["name"] == name, res
@pytest.mark.parametrize(
"chunk_method, expected_code, expected_message",
[
("naive", 0, ""),
("manual", 0, ""),
("qa", 0, ""),
("table", 0, ""),
("paper", 0, ""),
("book", 0, ""),
("laws", 0, ""),
("presentation", 0, ""),
("picture", 0, ""),
("one", 0, ""),
("email", 0, ""),
("tag", 0, ""),
("", 0, ""),
(
"other_chunk_method",
102,
"'other_chunk_method' is not in ['naive', 'manual', 'qa', 'table', 'paper', 'book', 'laws', 'presentation', 'picture', 'one', 'email', 'tag']",
),
],
)
def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method, expected_code, expected_message):
dataset_id = add_dataset_func
res = update_dataset(get_http_api_auth, dataset_id, {"chunk_method": chunk_method})
assert res["code"] == expected_code
if expected_code == 0:
res = list_datasets(get_http_api_auth, {"id": dataset_id})
if chunk_method != "":
assert res["data"][0]["chunk_method"] == chunk_method
else:
assert res["data"][0]["chunk_method"] == "naive"
else:
assert res["message"] == expected_message
@pytest.mark.p2
@pytest.mark.parametrize(
"name, expected_message",
[
("", "String should have at least 1 character"),
(" ", "String should have at least 1 character"),
("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"),
(0, "Input should be a valid string"),
(None, "Input should be a valid string"),
],
ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"],
)
def test_name_invalid(self, get_http_api_auth, add_dataset_func, name, expected_message):
dataset_id = add_dataset_func
payload = {"name": name}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p3
def test_name_duplicated(self, get_http_api_auth, add_datasets_func):
dataset_ids = add_datasets_func[0]
name = "dataset_1"
payload = {"name": name}
res = update_dataset(get_http_api_auth, dataset_ids, payload)
assert res["code"] == 102, res
assert res["message"] == f"Dataset name '{name}' already exists", res
@pytest.mark.p3
def test_name_case_insensitive(self, get_http_api_auth, add_datasets_func):
dataset_id = add_datasets_func[0]
name = "DATASET_1"
payload = {"name": name}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 102, res
assert res["message"] == f"Dataset name '{name}' already exists", res
@pytest.mark.p2
def test_avatar(self, get_http_api_auth, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {"avatar": encode_avatar(fn)}
payload = {
"avatar": f"data:image/png;base64,{encode_avatar(fn)}",
}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0
assert res["code"] == 0, res
res = list_datasets(get_http_api_auth)
assert res["code"] == 0, res
assert res["data"][0]["avatar"] == f"data:image/png;base64,{encode_avatar(fn)}", res
@pytest.mark.p2
def test_avatar_exceeds_limit_length(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"avatar": "a" * 65536}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
assert "String should have at most 65535 characters" in res["message"], res
@pytest.mark.p3
@pytest.mark.parametrize(
"name, avatar_prefix, expected_message",
[
("empty_prefix", "", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"),
("missing_comma", "data:image/png;base64", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"),
("unsupported_mine_type", "invalid_mine_prefix:image/png;base64,", "Invalid MIME prefix format. Must start with 'data:'"),
("invalid_mine_type", "data:unsupported_mine_type;base64,", "Unsupported MIME type. Allowed: ['image/jpeg', 'image/png']"),
],
ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"],
)
def test_avatar_invalid_prefix(self, get_http_api_auth, add_dataset_func, tmp_path, name, avatar_prefix, expected_message):
dataset_id = add_dataset_func
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {
"name": name,
"avatar": f"{avatar_prefix}{encode_avatar(fn)}",
}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
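# The expected messages above suggest data-URI prefix checking along these lines (a sketch
# under assumptions; not necessarily the actual implementation):
def _check_avatar_prefix_sketch(avatar: str) -> str | None:
    if "," not in avatar:
        return "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"
    prefix = avatar.split(",", 1)[0]
    if not prefix.startswith("data:"):
        return "Invalid MIME prefix format. Must start with 'data:'"
    mime = prefix[len("data:"):].split(";", 1)[0]
    if mime not in ("image/jpeg", "image/png"):
        return "Unsupported MIME type. Allowed: ['image/jpeg', 'image/png']"
    return None  # prefix looks valid

# For example, _check_avatar_prefix_sketch("data:image/png;base64") reports the missing comma,
# and _check_avatar_prefix_sketch("data:image/gif;base64,...") reports the unsupported MIME type.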
@pytest.mark.p3
def test_avatar_none(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"avatar": None}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(get_http_api_auth)
assert res["code"] == 0, res
assert res["data"][0]["avatar"] is None, res
@pytest.mark.p2
def test_description(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"description": "description"}
@ -153,95 +217,533 @@ class TestDatasetUpdate:
assert res["code"] == 0
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["code"] == 0, res
assert res["data"][0]["description"] == "description"
def test_pagerank(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"pagerank": 1}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["data"][0]["pagerank"] == 1
def test_similarity_threshold(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"similarity_threshold": 1}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["data"][0]["similarity_threshold"] == 1
@pytest.mark.parametrize(
"permission, expected_code",
[
("me", 0),
("team", 0),
("", 0),
("ME", 102),
("TEAM", 102),
("other_permission", 102),
],
)
def test_permission(self, get_http_api_auth, add_dataset_func, permission, expected_code):
@pytest.mark.p2
def test_description_exceeds_limit_length(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"description": "a" * 65536}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
assert "String should have at most 65535 characters" in res["message"], res
@pytest.mark.p3
def test_description_none(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"description": None}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["code"] == 0, res
assert res["data"][0]["description"] is None
@pytest.mark.p1
@pytest.mark.parametrize(
"embedding_model",
[
"BAAI/bge-large-zh-v1.5@BAAI",
"maidalun1020/bce-embedding-base_v1@Youdao",
"embedding-3@ZHIPU-AI",
],
ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"],
)
def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model):
dataset_id = add_dataset_func
payload = {"embedding_model": embedding_model}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(get_http_api_auth)
assert res["code"] == 0, res
assert res["data"][0]["embedding_model"] == embedding_model, res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, embedding_model",
[
("unknown_llm_name", "unknown@ZHIPU-AI"),
("unknown_llm_factory", "embedding-3@unknown"),
("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"),
("tenant_no_auth", "text-embedding-3-small@OpenAI"),
],
ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"],
)
def test_embedding_model_invalid(self, get_http_api_auth, add_dataset_func, name, embedding_model):
dataset_id = add_dataset_func
payload = {"name": name, "embedding_model": embedding_model}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
if "tenant_no_auth" in name:
assert res["message"] == f"Unauthorized model: <{embedding_model}>", res
else:
assert res["message"] == f"Unsupported model: <{embedding_model}>", res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, embedding_model",
[
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
("missing_model_name", "@BAAI"),
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
("whitespace_only_model_name", " @BAAI"),
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
],
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
)
def test_embedding_model_format(self, get_http_api_auth, add_dataset_func, name, embedding_model):
dataset_id = add_dataset_func
payload = {"name": name, "embedding_model": embedding_model}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
if name == "missing_at":
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
else:
assert "Both model_name and provider must be non-empty strings" in res["message"], res
@pytest.mark.p2
def test_embedding_model_none(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"embedding_model": None}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
assert "Input should be a valid string" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
"name, permission",
[
("me", "me"),
("team", "team"),
("me_upercase", "ME"),
("team_upercase", "TEAM"),
],
ids=["me", "team", "me_upercase", "team_upercase"],
)
def test_permission(self, get_http_api_auth, add_dataset_func, name, permission):
dataset_id = add_dataset_func
payload = {"name": name, "permission": permission}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(get_http_api_auth)
assert res["code"] == 0, res
assert res["data"][0]["permission"] == permission.lower(), res
@pytest.mark.p2
@pytest.mark.parametrize(
"permission",
[
"",
"unknown",
list(),
],
ids=["empty", "unknown", "type_error"],
)
def test_permission_invalid(self, get_http_api_auth, add_dataset_func, permission):
dataset_id = add_dataset_func
payload = {"permission": permission}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101
assert "Input should be 'me' or 'team'" in res["message"]
@pytest.mark.p3
def test_permission_none(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"name": "test_permission_none", "permission": None}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
assert "Input should be 'me' or 'team'" in res["message"], res
assert res["code"] == expected_code
res = list_datasets(get_http_api_auth, {"id": dataset_id})
if expected_code == 0 and permission != "":
assert res["data"][0]["permission"] == permission
if permission == "":
assert res["data"][0]["permission"] == "me"
def test_vector_similarity_weight(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"vector_similarity_weight": 1}
@pytest.mark.p1
@pytest.mark.parametrize(
"chunk_method",
[
"naive",
"book",
"email",
"laws",
"manual",
"one",
"paper",
"picture",
"presentation",
"qa",
"table",
"tag",
],
ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
)
def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method):
dataset_id = add_dataset_func
payload = {"chunk_method": chunk_method}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(get_http_api_auth)
assert res["code"] == 0, res
assert res["data"][0]["chunk_method"] == chunk_method, res
@pytest.mark.p2
@pytest.mark.parametrize(
"chunk_method",
[
"",
"unknown",
list(),
],
ids=["empty", "unknown", "type_error"],
)
def test_chunk_method_invalid(self, get_http_api_auth, add_dataset_func, chunk_method):
dataset_id = add_dataset_func
payload = {"chunk_method": chunk_method}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
@pytest.mark.p3
def test_chunk_method_none(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"chunk_method": None}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
@pytest.mark.p2
@pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
def test_pagerank(self, get_http_api_auth, add_dataset_func, pagerank):
dataset_id = add_dataset_func
payload = {"pagerank": pagerank}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["data"][0]["vector_similarity_weight"] == 1
def test_invalid_dataset_id(self, get_http_api_auth):
res = update_dataset(get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"})
assert res["code"] == 102
assert res["message"] == "You don't own the dataset"
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["code"] == 0, res
assert res["data"][0]["pagerank"] == pagerank
@pytest.mark.p2
@pytest.mark.parametrize(
"pagerank, expected_message",
[
(-1, "Input should be greater than or equal to 0"),
(101, "Input should be less than or equal to 100"),
],
ids=["min_limit", "max_limit"],
)
def test_pagerank_invalid(self, get_http_api_auth, add_dataset_func, pagerank, expected_message):
dataset_id = add_dataset_func
payload = {"pagerank": pagerank}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p3
def test_pagerank_none(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"pagerank": None}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
assert "Input should be a valid integer" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
"parser_config",
[
{"auto_keywords": 0},
{"auto_keywords": 16},
{"auto_keywords": 32},
{"auto_questions": 0},
{"auto_questions": 5},
{"auto_questions": 10},
{"chunk_token_num": 1},
{"chunk_token_num": 1024},
{"chunk_token_num": 2048},
{"delimiter": "\n"},
{"delimiter": " "},
{"html4excel": True},
{"html4excel": False},
{"layout_recognize": "DeepDOC"},
{"layout_recognize": "Plain Text"},
{"tag_kb_ids": ["1", "2"]},
{"topn_tags": 1},
{"topn_tags": 5},
{"topn_tags": 10},
{"filename_embd_weight": 0.1},
{"filename_embd_weight": 0.5},
{"filename_embd_weight": 1.0},
{"task_page_size": 1},
{"task_page_size": None},
{"pages": [[1, 100]]},
{"pages": None},
{"graphrag": {"use_graphrag": True}},
{"graphrag": {"use_graphrag": False}},
{"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}},
{"graphrag": {"method": "general"}},
{"graphrag": {"method": "light"}},
{"graphrag": {"community": True}},
{"graphrag": {"community": False}},
{"graphrag": {"resolution": True}},
{"graphrag": {"resolution": False}},
{"raptor": {"use_raptor": True}},
{"raptor": {"use_raptor": False}},
{"raptor": {"prompt": "Who are you?"}},
{"raptor": {"max_token": 1}},
{"raptor": {"max_token": 1024}},
{"raptor": {"max_token": 2048}},
{"raptor": {"threshold": 0.0}},
{"raptor": {"threshold": 0.5}},
{"raptor": {"threshold": 1.0}},
{"raptor": {"max_cluster": 1}},
{"raptor": {"max_cluster": 512}},
{"raptor": {"max_cluster": 1024}},
{"raptor": {"random_seed": 0}},
],
ids=[
"auto_keywords_min",
"auto_keywords_mid",
"auto_keywords_max",
"auto_questions_min",
"auto_questions_mid",
"auto_questions_max",
"chunk_token_num_min",
"chunk_token_num_mid",
"chunk_token_num_max",
"delimiter",
"delimiter_space",
"html4excel_true",
"html4excel_false",
"layout_recognize_DeepDOC",
"layout_recognize_navie",
"tag_kb_ids",
"topn_tags_min",
"topn_tags_mid",
"topn_tags_max",
"filename_embd_weight_min",
"filename_embd_weight_mid",
"filename_embd_weight_max",
"task_page_size_min",
"task_page_size_None",
"pages",
"pages_none",
"graphrag_true",
"graphrag_false",
"graphrag_entity_types",
"graphrag_method_general",
"graphrag_method_light",
"graphrag_community_true",
"graphrag_community_false",
"graphrag_resolution_true",
"graphrag_resolution_false",
"raptor_true",
"raptor_false",
"raptor_prompt",
"raptor_max_token_min",
"raptor_max_token_mid",
"raptor_max_token_max",
"raptor_threshold_min",
"raptor_threshold_mid",
"raptor_threshold_max",
"raptor_max_cluster_min",
"raptor_max_cluster_mid",
"raptor_max_cluster_max",
"raptor_random_seed_min",
],
)
def test_parser_config(self, get_http_api_auth, add_dataset_func, parser_config):
dataset_id = add_dataset_func
payload = {"parser_config": parser_config}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(get_http_api_auth)
assert res["code"] == 0, res
for k, v in parser_config.items():
if isinstance(v, dict):
for kk, vv in v.items():
assert res["data"][0]["parser_config"][k][kk] == vv, res
else:
assert res["data"][0]["parser_config"][k] == v, res
@pytest.mark.p2
@pytest.mark.parametrize(
"parser_config, expected_message",
[
({"auto_keywords": -1}, "Input should be greater than or equal to 0"),
({"auto_keywords": 33}, "Input should be less than or equal to 32"),
({"auto_keywords": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
({"auto_keywords": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
({"auto_questions": -1}, "Input should be greater than or equal to 0"),
({"auto_questions": 11}, "Input should be less than or equal to 10"),
({"auto_questions": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
({"auto_questions": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
({"chunk_token_num": 0}, "Input should be greater than or equal to 1"),
({"chunk_token_num": 2049}, "Input should be less than or equal to 2048"),
({"chunk_token_num": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
({"chunk_token_num": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
({"delimiter": ""}, "String should have at least 1 character"),
({"html4excel": "string"}, "Input should be a valid boolean, unable to interpret input"),
({"tag_kb_ids": "1,2"}, "Input should be a valid list"),
({"tag_kb_ids": [1, 2]}, "Input should be a valid string"),
({"topn_tags": 0}, "Input should be greater than or equal to 1"),
({"topn_tags": 11}, "Input should be less than or equal to 10"),
({"topn_tags": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
({"topn_tags": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
({"filename_embd_weight": -1}, "Input should be greater than or equal to 0"),
({"filename_embd_weight": 1.1}, "Input should be less than or equal to 1"),
({"filename_embd_weight": "string"}, "Input should be a valid number, unable to parse string as a number"),
({"task_page_size": 0}, "Input should be greater than or equal to 1"),
({"task_page_size": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
({"task_page_size": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
({"pages": "1,2"}, "Input should be a valid list"),
({"pages": ["1,2"]}, "Input should be a valid list"),
({"pages": [["string1", "string2"]]}, "Input should be a valid integer, unable to parse string as an integer"),
({"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean, unable to interpret input"),
({"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"),
({"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"),
({"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"),
({"graphrag": {"method": None}}, "Input should be 'light' or 'general'"),
({"graphrag": {"community": "string"}}, "Input should be a valid boolean, unable to interpret input"),
({"graphrag": {"resolution": "string"}}, "Input should be a valid boolean, unable to interpret input"),
({"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean, unable to interpret input"),
({"raptor": {"prompt": ""}}, "String should have at least 1 character"),
({"raptor": {"prompt": " "}}, "String should have at least 1 character"),
({"raptor": {"max_token": 0}}, "Input should be greater than or equal to 1"),
({"raptor": {"max_token": 2049}}, "Input should be less than or equal to 2048"),
({"raptor": {"max_token": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
({"raptor": {"max_token": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
({"raptor": {"threshold": -0.1}}, "Input should be greater than or equal to 0"),
({"raptor": {"threshold": 1.1}}, "Input should be less than or equal to 1"),
({"raptor": {"threshold": "string"}}, "Input should be a valid number, unable to parse string as a number"),
({"raptor": {"max_cluster": 0}}, "Input should be greater than or equal to 1"),
({"raptor": {"max_cluster": 1025}}, "Input should be less than or equal to 1024"),
({"raptor": {"max_cluster": 3.14}}, "Input should be a valid integer, got a number with a fractional par"),
({"raptor": {"max_cluster": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
({"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"),
({"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
({"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
({"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"),
],
ids=[
"auto_keywords_min_limit",
"auto_keywords_max_limit",
"auto_keywords_float_not_allowed",
"auto_keywords_type_invalid",
"auto_questions_min_limit",
"auto_questions_max_limit",
"auto_questions_float_not_allowed",
"auto_questions_type_invalid",
"chunk_token_num_min_limit",
"chunk_token_num_max_limit",
"chunk_token_num_float_not_allowed",
"chunk_token_num_type_invalid",
"delimiter_empty",
"html4excel_type_invalid",
"tag_kb_ids_not_list",
"tag_kb_ids_int_in_list",
"topn_tags_min_limit",
"topn_tags_max_limit",
"topn_tags_float_not_allowed",
"topn_tags_type_invalid",
"filename_embd_weight_min_limit",
"filename_embd_weight_max_limit",
"filename_embd_weight_type_invalid",
"task_page_size_min_limit",
"task_page_size_float_not_allowed",
"task_page_size_type_invalid",
"pages_not_list",
"pages_not_list_in_list",
"pages_not_int_list",
"graphrag_type_invalid",
"graphrag_entity_types_not_list",
"graphrag_entity_types_not_str_in_list",
"graphrag_method_unknown",
"graphrag_method_none",
"graphrag_community_type_invalid",
"graphrag_resolution_type_invalid",
"raptor_type_invalid",
"raptor_prompt_empty",
"raptor_prompt_space",
"raptor_max_token_min_limit",
"raptor_max_token_max_limit",
"raptor_max_token_float_not_allowed",
"raptor_max_token_type_invalid",
"raptor_threshold_min_limit",
"raptor_threshold_max_limit",
"raptor_threshold_type_invalid",
"raptor_max_cluster_min_limit",
"raptor_max_cluster_max_limit",
"raptor_max_cluster_float_not_allowed",
"raptor_max_cluster_type_invalid",
"raptor_random_seed_min_limit",
"raptor_random_seed_float_not_allowed",
"raptor_random_seed_type_invalid",
"parser_config_type_invalid",
],
)
def test_parser_config_invalid(self, get_http_api_auth, add_dataset_func, parser_config, expected_message):
dataset_id = add_dataset_func
payload = {"parser_config": parser_config}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
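# Read together, the cases above pin down the bounds enforced on parser_config during
# update. A condensed sketch of those constraints (assumed field shapes; the real
# ParserConfig and its nested raptor/graphrag models may differ):
from typing import Annotated, Optional
from pydantic import BaseModel, Field, StringConstraints

class _RaptorSketch(BaseModel):
    use_raptor: bool = False
    # Whitespace-only prompts are rejected, so strip before the length check.
    prompt: Optional[Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)]] = None
    max_token: Optional[int] = Field(default=None, ge=1, le=2048)
    threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    max_cluster: Optional[int] = Field(default=None, ge=1, le=1024)
    random_seed: Optional[int] = Field(default=None, ge=0)

class _ParserConfigSketch(BaseModel):
    auto_keywords: int = Field(default=0, ge=0, le=32)
    auto_questions: int = Field(default=0, ge=0, le=10)
    chunk_token_num: int = Field(default=128, ge=1, le=2048)
    delimiter: str = Field(default=r"\n", min_length=1)
    html4excel: bool = False
    topn_tags: int = Field(default=1, ge=1, le=10)
    filename_embd_weight: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    task_page_size: Optional[int] = Field(default=None, ge=1)
    raptor: Optional[_RaptorSketch] = None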
@pytest.mark.p2
def test_parser_config_empty(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"parser_config": {}}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(get_http_api_auth)
assert res["code"] == 0, res
assert res["data"][0]["parser_config"] == {}
# @pytest.mark.p2
# def test_parser_config_unset(self, get_http_api_auth, add_dataset_func):
# dataset_id = add_dataset_func
# payload = {"name": "default_unset"}
# res = update_dataset(get_http_api_auth, dataset_id, payload)
# assert res["code"] == 0, res
# res = list_datasets(get_http_api_auth)
# assert res["code"] == 0, res
# assert res["data"][0]["parser_config"] == {
# "chunk_token_num": 128,
# "delimiter": r"\n",
# "html4excel": False,
# "layout_recognize": "DeepDOC",
# "raptor": {"use_raptor": False},
# }, res
@pytest.mark.p3
def test_parser_config_none(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"parser_config": None}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
assert "Input should be a valid dictionary or instance of ParserConfig" in res["message"], res
@pytest.mark.p2
@pytest.mark.parametrize(
"payload",
[
{"id": "id"},
{"tenant_id": "e57c1966f99211efb41e9e45646e0111"},
{"created_by": "created_by"},
{"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
{"create_time": 1741671443322},
{"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
{"update_time": 1741671443339},
{"document_count": 1},
{"chunk_count": 1},
{"token_num": 1},
{"status": "1"},
{"unknown_field": "unknown_field"},
],
)
def test_unsupported_field(self, get_http_api_auth, add_dataset_func, payload):
dataset_id = add_dataset_func
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101, res
assert "Extra inputs are not permitted" in res["message"], res
@pytest.mark.parametrize(
"payload",
[
{"chunk_count": 1},
{"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
{"create_time": 1741671443322},
{"created_by": "aa"},
{"document_count": 1},
{"id": "id"},
{"status": "1"},
{"tenant_id": "e57c1966f99211efb41e9e45646e0111"},
{"token_num": 1},
{"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
{"update_time": 1741671443339},
],
)
def test_modify_read_only_field(self, get_http_api_auth, add_dataset_func, payload):
dataset_id = add_dataset_func
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101
assert "is readonly" in res["message"]
def test_modify_unknown_field(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
res = update_dataset(get_http_api_auth, dataset_id, {"unknown_field": 0})
assert res["code"] == 100
@pytest.mark.p3
def test_concurrent_update(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
from time import sleep
import pytest
from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets, stop_parse_documnets
@ -173,6 +174,7 @@ def test_stop_parse_100_files(get_http_api_auth, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
sleep(1)
res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)