mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-10 19:08:58 +08:00
Refa: HTTP API update dataset / test cases / docs (#7564)
### What problem does this PR solve? This PR introduces Pydantic-based validation for the update dataset HTTP API, improving code clarity and robustness. Key changes include: 1. Pydantic Validation 2. Error Handling 3. Test Updates 4. Documentation Updates 5. fix bug: #5915 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Documentation Update - [x] Refactoring
This commit is contained in:
parent
31718581b5
commit
35e36cb945
@ -27,22 +27,19 @@ from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMService, TenantLLMService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import (
|
||||
check_duplicate_ids,
|
||||
dataset_readonly_fields,
|
||||
deep_merge,
|
||||
get_error_argument_result,
|
||||
get_error_data_result,
|
||||
get_parser_config,
|
||||
get_result,
|
||||
token_required,
|
||||
valid,
|
||||
valid_parser_config,
|
||||
verify_embedding_availability,
|
||||
)
|
||||
from api.utils.validation_utils import CreateDatasetReq, validate_and_parse_json_request
|
||||
from api.utils.validation_utils import CreateDatasetReq, UpdateDatasetReq, validate_and_parse_json_request
|
||||
|
||||
|
||||
@manager.route("/datasets", methods=["POST"]) # noqa: F821
|
||||
@ -117,8 +114,8 @@ def create(tenant_id):
|
||||
return get_error_argument_result(err)
|
||||
|
||||
try:
|
||||
if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||
return get_error_argument_result(message=f"Dataset name '{req['name']}' already exists")
|
||||
if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
@ -145,7 +142,7 @@ def create(tenant_id):
|
||||
|
||||
try:
|
||||
if not KnowledgebaseService.save(**req):
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
return get_error_data_result(message="Create dataset error.(Database error)")
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
@ -205,6 +202,7 @@ def delete(tenant_id):
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = request.json
|
||||
@ -291,16 +289,28 @@ def update(tenant_id, dataset_id):
|
||||
name:
|
||||
type: string
|
||||
description: New name of the dataset.
|
||||
avatar:
|
||||
type: string
|
||||
description: Updated base64 encoding of the avatar.
|
||||
description:
|
||||
type: string
|
||||
description: Updated description of the dataset.
|
||||
embedding_model:
|
||||
type: string
|
||||
description: Updated embedding model Name.
|
||||
permission:
|
||||
type: string
|
||||
enum: ['me', 'team']
|
||||
description: Updated permission.
|
||||
description: Updated dataset permission.
|
||||
chunk_method:
|
||||
type: string
|
||||
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
|
||||
"presentation", "picture", "one", "email", "tag"
|
||||
enum: ["naive", "book", "email", "laws", "manual", "one", "paper",
|
||||
"picture", "presentation", "qa", "table", "tag"
|
||||
]
|
||||
description: Updated chunking method.
|
||||
pagerank:
|
||||
type: integer
|
||||
description: Updated page rank.
|
||||
parser_config:
|
||||
type: object
|
||||
description: Updated parser configuration.
|
||||
@ -310,98 +320,56 @@ def update(tenant_id, dataset_id):
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
|
||||
return get_error_data_result(message="You don't own the dataset")
|
||||
req = request.json
|
||||
for k in req.keys():
|
||||
if dataset_readonly_fields(k):
|
||||
return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"'{k}' is readonly.")
|
||||
e, t = TenantService.get_by_id(tenant_id)
|
||||
invalid_keys = {"id", "embd_id", "chunk_num", "doc_num", "parser_id", "create_date", "create_time", "created_by", "status", "token_num", "update_date", "update_time"}
|
||||
if any(key in req for key in invalid_keys):
|
||||
return get_error_data_result(message="The input parameters are invalid.")
|
||||
permission = req.get("permission")
|
||||
chunk_method = req.get("chunk_method")
|
||||
parser_config = req.get("parser_config")
|
||||
valid_parser_config(parser_config)
|
||||
valid_permission = ["me", "team"]
|
||||
valid_chunk_method = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "email", "tag"]
|
||||
check_validation = valid(
|
||||
permission,
|
||||
valid_permission,
|
||||
chunk_method,
|
||||
valid_chunk_method,
|
||||
)
|
||||
if check_validation:
|
||||
return check_validation
|
||||
if "tenant_id" in req:
|
||||
if req["tenant_id"] != tenant_id:
|
||||
return get_error_data_result(message="Can't change `tenant_id`.")
|
||||
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if "parser_config" in req:
|
||||
temp_dict = kb.parser_config
|
||||
temp_dict.update(req["parser_config"])
|
||||
req["parser_config"] = temp_dict
|
||||
if "chunk_count" in req:
|
||||
if req["chunk_count"] != kb.chunk_num:
|
||||
return get_error_data_result(message="Can't change `chunk_count`.")
|
||||
req.pop("chunk_count")
|
||||
if "document_count" in req:
|
||||
if req["document_count"] != kb.doc_num:
|
||||
return get_error_data_result(message="Can't change `document_count`.")
|
||||
req.pop("document_count")
|
||||
if req.get("chunk_method"):
|
||||
if kb.chunk_num != 0 and req["chunk_method"] != kb.parser_id:
|
||||
return get_error_data_result(message="If `chunk_count` is not 0, `chunk_method` is not changeable.")
|
||||
req["parser_id"] = req.pop("chunk_method")
|
||||
if req["parser_id"] != kb.parser_id:
|
||||
if not req.get("parser_config"):
|
||||
req["parser_config"] = get_parser_config(chunk_method, parser_config)
|
||||
if "embedding_model" in req:
|
||||
if kb.chunk_num != 0 and req["embedding_model"] != kb.embd_id:
|
||||
return get_error_data_result(message="If `chunk_count` is not 0, `embedding_model` is not changeable.")
|
||||
if not req.get("embedding_model"):
|
||||
return get_error_data_result("`embedding_model` can't be empty")
|
||||
valid_embedding_models = [
|
||||
"BAAI/bge-large-zh-v1.5",
|
||||
"BAAI/bge-base-en-v1.5",
|
||||
"BAAI/bge-large-en-v1.5",
|
||||
"BAAI/bge-small-en-v1.5",
|
||||
"BAAI/bge-small-zh-v1.5",
|
||||
"jinaai/jina-embeddings-v2-base-en",
|
||||
"jinaai/jina-embeddings-v2-small-en",
|
||||
"nomic-ai/nomic-embed-text-v1.5",
|
||||
"sentence-transformers/all-MiniLM-L6-v2",
|
||||
"text-embedding-v2",
|
||||
"text-embedding-v3",
|
||||
"maidalun1020/bce-embedding-base_v1",
|
||||
]
|
||||
embd_model = LLMService.query(llm_name=req["embedding_model"], model_type="embedding")
|
||||
if embd_model:
|
||||
if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(
|
||||
tenant_id=tenant_id,
|
||||
model_type="embedding",
|
||||
llm_name=req.get("embedding_model"),
|
||||
):
|
||||
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
|
||||
if not embd_model:
|
||||
embd_model = TenantLLMService.query(tenant_id=tenant_id, model_type="embedding", llm_name=req.get("embedding_model"))
|
||||
# Field name transformations during model dump:
|
||||
# | Original | Dump Output |
|
||||
# |----------------|-------------|
|
||||
# | embedding_model| embd_id |
|
||||
# | chunk_method | parser_id |
|
||||
extras = {"dataset_id": dataset_id}
|
||||
req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
if not req:
|
||||
return get_error_argument_result(message="No properties were modified")
|
||||
|
||||
try:
|
||||
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
|
||||
if kb is None:
|
||||
return get_error_data_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
|
||||
if req.get("parser_config"):
|
||||
req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"])
|
||||
|
||||
if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id and req.get("parser_config") is None:
|
||||
req["parser_config"] = get_parser_config(chunk_method, None)
|
||||
|
||||
if "name" in req and req["name"].lower() != kb.name.lower():
|
||||
try:
|
||||
exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
if exists:
|
||||
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
|
||||
if "embd_id" in req:
|
||||
if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
|
||||
return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
|
||||
ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
|
||||
if not ok:
|
||||
return err
|
||||
|
||||
try:
|
||||
if not KnowledgebaseService.update_by_id(kb.id, req):
|
||||
return get_error_data_result(message="Update dataset error.(Database error)")
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
|
||||
if not embd_model:
|
||||
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
|
||||
req["embd_id"] = req.pop("embedding_model")
|
||||
if "name" in req:
|
||||
req["name"] = req["name"].strip()
|
||||
if len(req["name"]) >= 128:
|
||||
return get_error_data_result(message="Dataset name should not be longer than 128 characters.")
|
||||
if req["name"].lower() != kb.name.lower() and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
|
||||
return get_error_data_result(message="Duplicated dataset name in updating dataset.")
|
||||
flds = list(req.keys())
|
||||
for f in flds:
|
||||
if req[f] == "" and f in ["permission", "parser_id", "chunk_method"]:
|
||||
del req[f]
|
||||
if not KnowledgebaseService.update_by_id(kb.id, req):
|
||||
return get_error_data_result(message="Update dataset error.(Database error)")
|
||||
return get_result(code=settings.RetCode.SUCCESS)
|
||||
|
||||
|
||||
|
@ -19,6 +19,7 @@ import logging
|
||||
import random
|
||||
import time
|
||||
from base64 import b64encode
|
||||
from copy import deepcopy
|
||||
from functools import wraps
|
||||
from hmac import HMAC
|
||||
from io import BytesIO
|
||||
@ -333,22 +334,6 @@ def generate_confirmation_token(tenant_id):
|
||||
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]
|
||||
|
||||
|
||||
def valid(permission, valid_permission, chunk_method, valid_chunk_method):
|
||||
if valid_parameter(permission, valid_permission):
|
||||
return valid_parameter(permission, valid_permission)
|
||||
if valid_parameter(chunk_method, valid_chunk_method):
|
||||
return valid_parameter(chunk_method, valid_chunk_method)
|
||||
|
||||
|
||||
def valid_parameter(parameter, valid_values):
|
||||
if parameter and parameter not in valid_values:
|
||||
return get_error_data_result(f"'{parameter}' is not in {valid_values}")
|
||||
|
||||
|
||||
def dataset_readonly_fields(field_name):
|
||||
return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time", "created_by", "document_count", "token_num", "status", "tenant_id", "id"]
|
||||
|
||||
|
||||
def get_parser_config(chunk_method, parser_config):
|
||||
if parser_config:
|
||||
return parser_config
|
||||
@ -402,43 +387,6 @@ def get_data_openai(
|
||||
}
|
||||
|
||||
|
||||
def valid_parser_config(parser_config):
|
||||
if not parser_config:
|
||||
return
|
||||
scopes = set(
|
||||
[
|
||||
"chunk_token_num",
|
||||
"delimiter",
|
||||
"raptor",
|
||||
"graphrag",
|
||||
"layout_recognize",
|
||||
"task_page_size",
|
||||
"pages",
|
||||
"html4excel",
|
||||
"auto_keywords",
|
||||
"auto_questions",
|
||||
"tag_kb_ids",
|
||||
"topn_tags",
|
||||
"filename_embd_weight",
|
||||
]
|
||||
)
|
||||
for k in parser_config.keys():
|
||||
assert k in scopes, f"Abnormal 'parser_config'. Invalid key: {k}"
|
||||
|
||||
assert isinstance(parser_config.get("chunk_token_num", 1), int), "chunk_token_num should be int"
|
||||
assert 1 <= parser_config.get("chunk_token_num", 1) < 100000000, "chunk_token_num should be in range from 1 to 100000000"
|
||||
assert isinstance(parser_config.get("task_page_size", 1), int), "task_page_size should be int"
|
||||
assert 1 <= parser_config.get("task_page_size", 1) < 100000000, "task_page_size should be in range from 1 to 100000000"
|
||||
assert isinstance(parser_config.get("auto_keywords", 1), int), "auto_keywords should be int"
|
||||
assert 0 <= parser_config.get("auto_keywords", 0) < 32, "auto_keywords should be in range from 0 to 32"
|
||||
assert isinstance(parser_config.get("auto_questions", 1), int), "auto_questions should be int"
|
||||
assert 0 <= parser_config.get("auto_questions", 0) < 10, "auto_questions should be in range from 0 to 10"
|
||||
assert isinstance(parser_config.get("topn_tags", 1), int), "topn_tags should be int"
|
||||
assert 0 <= parser_config.get("topn_tags", 0) < 10, "topn_tags should be in range from 0 to 10"
|
||||
assert isinstance(parser_config.get("html4excel", False), bool), "html4excel should be True or False"
|
||||
assert isinstance(parser_config.get("delimiter", ""), str), "delimiter should be str"
|
||||
|
||||
|
||||
def check_duplicate_ids(ids, id_type="item"):
|
||||
"""
|
||||
Check for duplicate IDs in a list and return unique IDs and error messages.
|
||||
@ -469,7 +417,8 @@ def check_duplicate_ids(ids, id_type="item"):
|
||||
|
||||
|
||||
def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]:
|
||||
"""Verifies availability of an embedding model for a specific tenant.
|
||||
"""
|
||||
Verifies availability of an embedding model for a specific tenant.
|
||||
|
||||
Implements a four-stage validation process:
|
||||
1. Model identifier parsing and validation
|
||||
@ -518,3 +467,50 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, R
|
||||
return False, get_error_data_result(message="Database operation failed")
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def deep_merge(default: dict, custom: dict) -> dict:
|
||||
"""
|
||||
Recursively merges two dictionaries with priority given to `custom` values.
|
||||
|
||||
Creates a deep copy of the `default` dictionary and iteratively merges nested
|
||||
dictionaries using a stack-based approach. Non-dict values in `custom` will
|
||||
completely override corresponding entries in `default`.
|
||||
|
||||
Args:
|
||||
default (dict): Base dictionary containing default values.
|
||||
custom (dict): Dictionary containing overriding values.
|
||||
|
||||
Returns:
|
||||
dict: New merged dictionary combining values from both inputs.
|
||||
|
||||
Example:
|
||||
>>> from copy import deepcopy
|
||||
>>> default = {"a": 1, "nested": {"x": 10, "y": 20}}
|
||||
>>> custom = {"b": 2, "nested": {"y": 99, "z": 30}}
|
||||
>>> deep_merge(default, custom)
|
||||
{'a': 1, 'b': 2, 'nested': {'x': 10, 'y': 99, 'z': 30}}
|
||||
|
||||
>>> deep_merge({"config": {"mode": "auto"}}, {"config": "manual"})
|
||||
{'config': 'manual'}
|
||||
|
||||
Notes:
|
||||
1. Merge priority is always given to `custom` values at all nesting levels
|
||||
2. Non-dict values (e.g. list, str) in `custom` will replace entire values
|
||||
in `default`, even if the original value was a dictionary
|
||||
3. Time complexity: O(N) where N is total key-value pairs in `custom`
|
||||
4. Recommended for configuration merging and nested data updates
|
||||
"""
|
||||
merged = deepcopy(default)
|
||||
stack = [(merged, custom)]
|
||||
|
||||
while stack:
|
||||
base_dict, override_dict = stack.pop()
|
||||
|
||||
for key, val in override_dict.items():
|
||||
if key in base_dict and isinstance(val, dict) and isinstance(base_dict[key], dict):
|
||||
stack.append((base_dict[key], val))
|
||||
else:
|
||||
base_dict[key] = val
|
||||
|
||||
return merged
|
||||
|
@ -13,21 +13,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import uuid
|
||||
from enum import auto
|
||||
from typing import Annotated, Any
|
||||
|
||||
from flask import Request
|
||||
from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator
|
||||
from pydantic import UUID1, BaseModel, Field, StringConstraints, ValidationError, field_serializer, field_validator
|
||||
from strenum import StrEnum
|
||||
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
|
||||
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
|
||||
|
||||
def validate_and_parse_json_request(request: Request, validator: type[BaseModel]) -> tuple[dict[str, Any] | None, str | None]:
|
||||
"""Validates and parses JSON requests through a multi-stage validation pipeline.
|
||||
def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
|
||||
"""
|
||||
Validates and parses JSON requests through a multi-stage validation pipeline.
|
||||
|
||||
Implements a robust four-stage validation process:
|
||||
Implements a four-stage validation process:
|
||||
1. Content-Type verification (must be application/json)
|
||||
2. JSON syntax validation
|
||||
3. Payload structure type checking
|
||||
@ -35,6 +37,10 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
|
||||
|
||||
Args:
|
||||
request (Request): Flask request object containing HTTP payload
|
||||
validator (type[BaseModel]): Pydantic model class for data validation
|
||||
extras (dict[str, Any] | None): Additional fields to merge into payload
|
||||
before validation. These fields will be removed from the final output
|
||||
exclude_unset (bool): Whether to exclude fields that have not been explicitly set
|
||||
|
||||
Returns:
|
||||
tuple[Dict[str, Any] | None, str | None]:
|
||||
@ -46,26 +52,26 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
|
||||
- Diagnostic error message on failure
|
||||
|
||||
Raises:
|
||||
UnsupportedMediaType: When Content-Type ≠ application/json
|
||||
UnsupportedMediaType: When Content-Type header is not application/json
|
||||
BadRequest: For structural JSON syntax errors
|
||||
ValidationError: When payload violates Pydantic schema rules
|
||||
|
||||
Examples:
|
||||
Successful validation:
|
||||
```python
|
||||
# Input: {"name": "Dataset1", "format": "csv"}
|
||||
# Returns: ({"name": "Dataset1", "format": "csv"}, None)
|
||||
```
|
||||
>>> validate_and_parse_json_request(valid_request, DatasetSchema)
|
||||
({"name": "Dataset1", "format": "csv"}, None)
|
||||
|
||||
Invalid Content-Type:
|
||||
```python
|
||||
# Returns: (None, "Unsupported content type: Expected application/json, got text/xml")
|
||||
```
|
||||
>>> validate_and_parse_json_request(xml_request, DatasetSchema)
|
||||
(None, "Unsupported content type: Expected application/json, got text/xml")
|
||||
|
||||
Malformed JSON:
|
||||
```python
|
||||
# Returns: (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding")
|
||||
```
|
||||
>>> validate_and_parse_json_request(bad_json_request, DatasetSchema)
|
||||
(None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding")
|
||||
|
||||
Notes:
|
||||
1. Validation Priority:
|
||||
- Content-Type verification precedes JSON parsing
|
||||
- Structural validation occurs before schema validation
|
||||
2. Extra fields added via `extras` parameter are automatically removed
|
||||
from the final output after validation
|
||||
"""
|
||||
try:
|
||||
payload = request.get_json() or {}
|
||||
@ -78,17 +84,25 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
|
||||
return None, f"Invalid request payload: expected object, got {type(payload).__name__}"
|
||||
|
||||
try:
|
||||
if extras is not None:
|
||||
payload.update(extras)
|
||||
validated_request = validator(**payload)
|
||||
except ValidationError as e:
|
||||
return None, format_validation_error_message(e)
|
||||
|
||||
parsed_payload = validated_request.model_dump(by_alias=True)
|
||||
parsed_payload = validated_request.model_dump(by_alias=True, exclude_unset=exclude_unset)
|
||||
|
||||
if extras is not None:
|
||||
for key in list(parsed_payload.keys()):
|
||||
if key in extras:
|
||||
del parsed_payload[key]
|
||||
|
||||
return parsed_payload, None
|
||||
|
||||
|
||||
def format_validation_error_message(e: ValidationError) -> str:
|
||||
"""Formats validation errors into a standardized string format.
|
||||
"""
|
||||
Formats validation errors into a standardized string format.
|
||||
|
||||
Processes pydantic ValidationError objects to create human-readable error messages
|
||||
containing field locations, error descriptions, and input values.
|
||||
@ -155,7 +169,6 @@ class GraphragMethodEnum(StrEnum):
|
||||
class Base(BaseModel):
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
json_schema_extra = {"charset": "utf8mb4", "collation": "utf8mb4_0900_ai_ci"}
|
||||
|
||||
|
||||
class RaptorConfig(Base):
|
||||
@ -201,16 +214,17 @@ class CreateDatasetReq(Base):
|
||||
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
|
||||
avatar: str | None = Field(default=None, max_length=65535)
|
||||
description: str | None = Field(default=None, max_length=65535)
|
||||
embedding_model: Annotated[str | None, StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")]
|
||||
embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
|
||||
permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)]
|
||||
chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")]
|
||||
pagerank: int = Field(default=0, ge=0, le=100)
|
||||
parser_config: ParserConfig | None = Field(default=None)
|
||||
parser_config: ParserConfig = Field(default_factory=dict)
|
||||
|
||||
@field_validator("avatar")
|
||||
@classmethod
|
||||
def validate_avatar_base64(cls, v: str) -> str:
|
||||
"""Validates Base64-encoded avatar string format and MIME type compliance.
|
||||
def validate_avatar_base64(cls, v: str | None) -> str | None:
|
||||
"""
|
||||
Validates Base64-encoded avatar string format and MIME type compliance.
|
||||
|
||||
Implements a three-stage validation workflow:
|
||||
1. MIME prefix existence check
|
||||
@ -259,7 +273,8 @@ class CreateDatasetReq(Base):
|
||||
@field_validator("embedding_model", mode="after")
|
||||
@classmethod
|
||||
def validate_embedding_model(cls, v: str) -> str:
|
||||
"""Validates embedding model identifier format compliance.
|
||||
"""
|
||||
Validates embedding model identifier format compliance.
|
||||
|
||||
Validation pipeline:
|
||||
1. Structural format verification
|
||||
@ -298,11 +313,12 @@ class CreateDatasetReq(Base):
|
||||
|
||||
@field_validator("permission", mode="before")
|
||||
@classmethod
|
||||
def permission_auto_lowercase(cls, v: str) -> str:
|
||||
"""Normalize permission input to lowercase for consistent PermissionEnum matching.
|
||||
def permission_auto_lowercase(cls, v: Any) -> Any:
|
||||
"""
|
||||
Normalize permission input to lowercase for consistent PermissionEnum matching.
|
||||
|
||||
Args:
|
||||
v (str): Raw input value for the permission field
|
||||
v (Any): Raw input value for the permission field
|
||||
|
||||
Returns:
|
||||
Lowercase string if input is string type, otherwise returns original value
|
||||
@ -316,13 +332,13 @@ class CreateDatasetReq(Base):
|
||||
|
||||
@field_validator("parser_config", mode="after")
|
||||
@classmethod
|
||||
def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserConfig | None:
|
||||
"""Validates serialized JSON length constraints for parser configuration.
|
||||
def validate_parser_config_json_length(cls, v: ParserConfig) -> ParserConfig:
|
||||
"""
|
||||
Validates serialized JSON length constraints for parser configuration.
|
||||
|
||||
Implements a three-stage validation workflow:
|
||||
1. Null check - bypass validation for empty configurations
|
||||
2. Model serialization - convert Pydantic model to JSON string
|
||||
3. Size verification - enforce maximum allowed payload size
|
||||
Implements a two-stage validation workflow:
|
||||
1. Model serialization - convert Pydantic model to JSON string
|
||||
2. Size verification - enforce maximum allowed payload size
|
||||
|
||||
Args:
|
||||
v (ParserConfig | None): Raw parser configuration object
|
||||
@ -333,9 +349,15 @@ class CreateDatasetReq(Base):
|
||||
Raises:
|
||||
ValueError: When serialized JSON exceeds 65,535 characters
|
||||
"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
if (json_str := v.model_dump_json()) and len(json_str) > 65535:
|
||||
raise ValueError(f"Parser config exceeds size limit (max 65,535 characters). Current size: {len(json_str):,}")
|
||||
return v
|
||||
|
||||
|
||||
class UpdateDatasetReq(CreateDatasetReq):
|
||||
dataset_id: UUID1 = Field(...)
|
||||
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
|
||||
|
||||
@field_serializer("dataset_id")
|
||||
def serialize_uuid_to_hex(self, v: uuid.UUID) -> str:
|
||||
return v.hex
|
||||
|
@ -385,7 +385,7 @@ curl --request POST \
|
||||
- `"team"`: All team members can manage the dataset.
|
||||
|
||||
- `"pagerank"`: (*Body parameter*), `int`
|
||||
Set page rank: refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank)
|
||||
refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank)
|
||||
- Default: `0`
|
||||
- Minimum: `0`
|
||||
- Maximum: `100`
|
||||
@ -562,8 +562,13 @@ Updates configurations for a specified dataset.
|
||||
- `'Authorization: Bearer <YOUR_API_KEY>'`
|
||||
- Body:
|
||||
- `"name"`: `string`
|
||||
- `"avatar"`: `string`
|
||||
- `"description"`: `string`
|
||||
- `"embedding_model"`: `string`
|
||||
- `"chunk_method"`: `enum<string>`
|
||||
- `"permission"`: `string`
|
||||
- `"chunk_method"`: `string`
|
||||
- `"pagerank"`: `int`
|
||||
- `"parser_config"`: `object`
|
||||
|
||||
##### Request example
|
||||
|
||||
@ -584,22 +589,74 @@ curl --request PUT \
|
||||
The ID of the dataset to update.
|
||||
- `"name"`: (*Body parameter*), `string`
|
||||
The revised name of the dataset.
|
||||
- Basic Multilingual Plane (BMP) only
|
||||
- Maximum 128 characters
|
||||
- Case-insensitive
|
||||
- `"avatar"`: (*Body parameter*), `string`
|
||||
The updated base64 encoding of the avatar.
|
||||
- Maximum 65535 characters
|
||||
- `"embedding_model"`: (*Body parameter*), `string`
|
||||
The updated embedding model name.
|
||||
- Ensure that `"chunk_count"` is `0` before updating `"embedding_model"`.
|
||||
- Maximum 255 characters
|
||||
- Must follow `model_name@model_factory` format
|
||||
- `"permission"`: (*Body parameter*), `string`
|
||||
The updated dataset permission. Available options:
|
||||
- `"me"`: (Default) Only you can manage the dataset.
|
||||
- `"team"`: All team members can manage the dataset.
|
||||
- `"pagerank"`: (*Body parameter*), `int`
|
||||
refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank)
|
||||
- Default: `0`
|
||||
- Minimum: `0`
|
||||
- Maximum: `100`
|
||||
- `"chunk_method"`: (*Body parameter*), `enum<string>`
|
||||
The chunking method for the dataset. Available options:
|
||||
- `"naive"`: General
|
||||
- `"manual`: Manual
|
||||
- `"naive"`: General (default)
|
||||
- `"book"`: Book
|
||||
- `"email"`: Email
|
||||
- `"laws"`: Laws
|
||||
- `"manual"`: Manual
|
||||
- `"one"`: One
|
||||
- `"paper"`: Paper
|
||||
- `"picture"`: Picture
|
||||
- `"presentation"`: Presentation
|
||||
- `"qa"`: Q&A
|
||||
- `"table"`: Table
|
||||
- `"paper"`: Paper
|
||||
- `"book"`: Book
|
||||
- `"laws"`: Laws
|
||||
- `"presentation"`: Presentation
|
||||
- `"picture"`: Picture
|
||||
- `"one"`:One
|
||||
- `"email"`: Email
|
||||
- `"tag"`: Tag
|
||||
- `"parser_config"`: (*Body parameter*), `object`
|
||||
The configuration settings for the dataset parser. The attributes in this JSON object vary with the selected `"chunk_method"`:
|
||||
- If `"chunk_method"` is `"naive"`, the `"parser_config"` object contains the following attributes:
|
||||
- `"auto_keywords"`: `int`
|
||||
- Defaults to `0`
|
||||
- Minimum: `0`
|
||||
- Maximum: `32`
|
||||
- `"auto_questions"`: `int`
|
||||
- Defaults to `0`
|
||||
- Minimum: `0`
|
||||
- Maximum: `10`
|
||||
- `"chunk_token_num"`: `int`
|
||||
- Defaults to `128`
|
||||
- Minimum: `1`
|
||||
- Maximum: `2048`
|
||||
- `"delimiter"`: `string`
|
||||
- Defaults to `"\n"`.
|
||||
- `"html4excel"`: `bool` Indicates whether to convert Excel documents into HTML format.
|
||||
- Defaults to `false`
|
||||
- `"layout_recognize"`: `string`
|
||||
- Defaults to `DeepDOC`
|
||||
- `"tag_kb_ids"`: `array<string>` refer to [Use tag set](https://ragflow.io/docs/dev/use_tag_sets)
|
||||
- Must include a list of dataset IDs, where each dataset is parsed using the Tag Chunk Method
|
||||
- `"task_page_size"`: `int` For PDF only.
|
||||
- Defaults to `12`
|
||||
- Minimum: `1`
|
||||
- `"raptor"`: `object` RAPTOR-specific settings.
|
||||
- Defaults to: `{"use_raptor": false}`
|
||||
- `"graphrag"`: `object` GRAPHRAG-specific settings.
|
||||
- Defaults to: `{"use_graphrag": false}`
|
||||
- If `"chunk_method"` is `"qa"`, `"manuel"`, `"paper"`, `"book"`, `"laws"`, or `"presentation"`, the `"parser_config"` object contains the following attribute:
|
||||
- `"raptor"`: `object` RAPTOR-specific settings.
|
||||
- Defaults to: `{"use_raptor": false}`.
|
||||
- If `"chunk_method"` is `"table"`, `"picture"`, `"one"`, or `"email"`, `"parser_config"` is an empty JSON object.
|
||||
|
||||
#### Response
|
||||
|
||||
|
@ -306,20 +306,40 @@ Updates configurations for the current dataset.
|
||||
A dictionary representing the attributes to update, with the following keys:
|
||||
|
||||
- `"name"`: `str` The revised name of the dataset.
|
||||
- `"embedding_model"`: `str` The updated embedding model name.
|
||||
- Basic Multilingual Plane (BMP) only
|
||||
- Maximum 128 characters
|
||||
- Case-insensitive
|
||||
- `"avatar"`: (*Body parameter*), `string`
|
||||
The updated base64 encoding of the avatar.
|
||||
- Maximum 65535 characters
|
||||
- `"embedding_model"`: (*Body parameter*), `string`
|
||||
The updated embedding model name.
|
||||
- Ensure that `"chunk_count"` is `0` before updating `"embedding_model"`.
|
||||
- `"chunk_method"`: `str` The chunking method for the dataset. Available options:
|
||||
- `"naive"`: General
|
||||
- `"manual`: Manual
|
||||
- Maximum 255 characters
|
||||
- Must follow `model_name@model_factory` format
|
||||
- `"permission"`: (*Body parameter*), `string`
|
||||
The updated dataset permission. Available options:
|
||||
- `"me"`: (Default) Only you can manage the dataset.
|
||||
- `"team"`: All team members can manage the dataset.
|
||||
- `"pagerank"`: (*Body parameter*), `int`
|
||||
refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank)
|
||||
- Default: `0`
|
||||
- Minimum: `0`
|
||||
- Maximum: `100`
|
||||
- `"chunk_method"`: (*Body parameter*), `enum<string>`
|
||||
The chunking method for the dataset. Available options:
|
||||
- `"naive"`: General (default)
|
||||
- `"book"`: Book
|
||||
- `"email"`: Email
|
||||
- `"laws"`: Laws
|
||||
- `"manual"`: Manual
|
||||
- `"one"`: One
|
||||
- `"paper"`: Paper
|
||||
- `"picture"`: Picture
|
||||
- `"presentation"`: Presentation
|
||||
- `"qa"`: Q&A
|
||||
- `"table"`: Table
|
||||
- `"paper"`: Paper
|
||||
- `"book"`: Book
|
||||
- `"laws"`: Laws
|
||||
- `"presentation"`: Presentation
|
||||
- `"picture"`: Picture
|
||||
- `"one"`: One
|
||||
- `"email"`: Email
|
||||
- `"tag"`: Tag
|
||||
|
||||
#### Returns
|
||||
|
||||
|
@ -59,21 +59,19 @@ class RAGFlow:
|
||||
pagerank: int = 0,
|
||||
parser_config: DataSet.ParserConfig = None,
|
||||
) -> DataSet:
|
||||
if parser_config:
|
||||
parser_config = parser_config.to_json()
|
||||
res = self.post(
|
||||
"/datasets",
|
||||
{
|
||||
"name": name,
|
||||
"avatar": avatar,
|
||||
"description": description,
|
||||
"embedding_model": embedding_model,
|
||||
"permission": permission,
|
||||
"chunk_method": chunk_method,
|
||||
"pagerank": pagerank,
|
||||
"parser_config": parser_config,
|
||||
},
|
||||
)
|
||||
payload = {
|
||||
"name": name,
|
||||
"avatar": avatar,
|
||||
"description": description,
|
||||
"embedding_model": embedding_model,
|
||||
"permission": permission,
|
||||
"chunk_method": chunk_method,
|
||||
"pagerank": pagerank,
|
||||
}
|
||||
if parser_config is not None:
|
||||
payload["parser_config"] = parser_config.to_json()
|
||||
|
||||
res = self.post("/datasets", payload)
|
||||
res = res.json()
|
||||
if res.get("code") == 0:
|
||||
return DataSet(self, res["data"])
|
||||
|
28
sdk/python/test/libs/utils/hypothesis_utils.py
Normal file
28
sdk/python/test/libs/utils/hypothesis_utils.py
Normal file
@ -0,0 +1,28 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import hypothesis.strategies as st
|
||||
|
||||
|
||||
@st.composite
|
||||
def valid_names(draw):
|
||||
base_chars = "abcdefghijklmnopqrstuvwxyz_"
|
||||
first_char = draw(st.sampled_from([c for c in base_chars if c.isalpha() or c == "_"]))
|
||||
remaining = draw(st.text(alphabet=st.sampled_from(base_chars), min_size=0, max_size=128 - 2))
|
||||
|
||||
name = (first_char + remaining)[:128]
|
||||
return name.encode("utf-8").decode("utf-8")
|
@ -39,23 +39,23 @@ SESSION_WITH_CHAT_NAME_LIMIT = 255
|
||||
|
||||
|
||||
# DATASET MANAGEMENT
|
||||
def create_dataset(auth, payload=None, headers=HEADERS, data=None):
|
||||
def create_dataset(auth, payload=None, *, headers=HEADERS, data=None):
|
||||
res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
|
||||
return res.json()
|
||||
|
||||
|
||||
def list_datasets(auth, params=None, headers=HEADERS):
|
||||
def list_datasets(auth, params=None, *, headers=HEADERS):
|
||||
res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params)
|
||||
return res.json()
|
||||
|
||||
|
||||
def update_dataset(auth, dataset_id, payload=None, headers=HEADERS):
|
||||
res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload)
|
||||
def update_dataset(auth, dataset_id, payload=None, *, headers=HEADERS, data=None):
|
||||
res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload, data=data)
|
||||
return res.json()
|
||||
|
||||
|
||||
def delete_datasets(auth, payload=None, headers=HEADERS):
|
||||
res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload)
|
||||
def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None):
|
||||
res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
|
||||
return res.json()
|
||||
|
||||
|
||||
|
@ -37,3 +37,13 @@ def add_datasets_func(get_http_api_auth, request):
|
||||
request.addfinalizer(cleanup)
|
||||
|
||||
return batch_create_datasets(get_http_api_auth, 3)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def add_dataset_func(get_http_api_auth, request):
|
||||
def cleanup():
|
||||
delete_datasets(get_http_api_auth)
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
|
||||
return batch_create_datasets(get_http_api_auth, 1)[0]
|
||||
|
@ -13,30 +13,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
import hypothesis.strategies as st
|
||||
import pytest
|
||||
from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, create_dataset
|
||||
from hypothesis import example, given, settings
|
||||
from libs.auth import RAGFlowHttpApiAuth
|
||||
from libs.utils import encode_avatar
|
||||
from libs.utils.file_utils import create_image_file
|
||||
from libs.utils.hypothesis_utils import valid_names
|
||||
|
||||
|
||||
@st.composite
|
||||
def valid_names(draw):
|
||||
base_chars = "abcdefghijklmnopqrstuvwxyz_"
|
||||
first_char = draw(st.sampled_from([c for c in base_chars if c.isalpha() or c == "_"]))
|
||||
remaining = draw(st.text(alphabet=st.sampled_from(base_chars), min_size=0, max_size=DATASET_NAME_LIMIT - 2))
|
||||
|
||||
name = (first_char + remaining)[:128]
|
||||
return name.encode("utf-8").decode("utf-8")
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.usefixtures("clear_datasets")
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"auth, expected_code, expected_message",
|
||||
[
|
||||
@ -49,64 +39,17 @@ class TestAuthorization:
|
||||
],
|
||||
ids=["empty_auth", "invalid_api_token"],
|
||||
)
|
||||
def test_invalid_auth(self, auth, expected_code, expected_message):
|
||||
def test_auth_invalid(self, auth, expected_code, expected_message):
|
||||
res = create_dataset(auth, {"name": "auth_test"})
|
||||
assert res["code"] == expected_code
|
||||
assert res["message"] == expected_message
|
||||
assert res["code"] == expected_code, res
|
||||
assert res["message"] == expected_message, res
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("clear_datasets")
|
||||
class TestDatasetCreation:
|
||||
@pytest.mark.p1
|
||||
@given(name=valid_names())
|
||||
@example("a" * 128)
|
||||
@settings(max_examples=20)
|
||||
def test_valid_name(self, get_http_api_auth, name):
|
||||
res = create_dataset(get_http_api_auth, {"name": name})
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["name"] == name, res
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"name, expected_message",
|
||||
[
|
||||
("", "String should have at least 1 character"),
|
||||
(" ", "String should have at least 1 character"),
|
||||
("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"),
|
||||
(0, "Input should be a valid string"),
|
||||
],
|
||||
ids=["empty_name", "space_name", "too_long_name", "invalid_name"],
|
||||
)
|
||||
def test_invalid_name(self, get_http_api_auth, name, expected_message):
|
||||
res = create_dataset(get_http_api_auth, {"name": name})
|
||||
assert res["code"] == 101, res
|
||||
assert expected_message in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_duplicated_name(self, get_http_api_auth):
|
||||
name = "duplicated_name"
|
||||
payload = {"name": name}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert res["message"] == f"Dataset name '{name}' already exists", res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_case_insensitive(self, get_http_api_auth):
|
||||
name = "CaseInsensitive"
|
||||
res = create_dataset(get_http_api_auth, {"name": name.upper()})
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = create_dataset(get_http_api_auth, {"name": name.lower()})
|
||||
assert res["code"] == 101, res
|
||||
assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
|
||||
|
||||
class TestRquest:
|
||||
@pytest.mark.p3
|
||||
def test_bad_content_type(self, get_http_api_auth):
|
||||
def test_content_type_bad(self, get_http_api_auth):
|
||||
BAD_CONTENT_TYPE = "text/xml"
|
||||
res = create_dataset(get_http_api_auth, {"name": "name"}, {"Content-Type": BAD_CONTENT_TYPE})
|
||||
res = create_dataset(get_http_api_auth, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE})
|
||||
assert res["code"] == 101, res
|
||||
assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res
|
||||
|
||||
@ -115,15 +58,85 @@ class TestDatasetCreation:
|
||||
"payload, expected_message",
|
||||
[
|
||||
("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"),
|
||||
('"a"', "Invalid request payload: expected objec"),
|
||||
('"a"', "Invalid request payload: expected object, got str"),
|
||||
],
|
||||
ids=["malformed_json_syntax", "invalid_request_payload_type"],
|
||||
)
|
||||
def test_bad_payload(self, get_http_api_auth, payload, expected_message):
|
||||
def test_payload_bad(self, get_http_api_auth, payload, expected_message):
|
||||
res = create_dataset(get_http_api_auth, data=payload)
|
||||
assert res["code"] == 101, res
|
||||
assert res["message"] == expected_message, res
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("clear_datasets")
|
||||
class TestCapability:
|
||||
@pytest.mark.p3
|
||||
def test_create_dataset_1k(self, get_http_api_auth):
|
||||
for i in range(1_000):
|
||||
payload = {"name": f"dataset_{i}"}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, f"Failed to create dataset {i}"
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_create_dataset_concurrent(self, get_http_api_auth):
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(create_dataset, get_http_api_auth, {"name": f"dataset_{i}"}) for i in range(100)]
|
||||
responses = [f.result() for f in futures]
|
||||
assert all(r["code"] == 0 for r in responses), responses
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("clear_datasets")
|
||||
class TestDatasetCreate:
|
||||
@pytest.mark.p1
|
||||
@given(name=valid_names())
|
||||
@example("a" * 128)
|
||||
@settings(max_examples=20)
|
||||
def test_name(self, get_http_api_auth, name):
|
||||
res = create_dataset(get_http_api_auth, {"name": name})
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["name"] == name, res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"name, expected_message",
|
||||
[
|
||||
("", "String should have at least 1 character"),
|
||||
(" ", "String should have at least 1 character"),
|
||||
("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"),
|
||||
(0, "Input should be a valid string"),
|
||||
(None, "Input should be a valid string"),
|
||||
],
|
||||
ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"],
|
||||
)
|
||||
def test_name_invalid(self, get_http_api_auth, name, expected_message):
|
||||
payload = {"name": name}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert expected_message in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_name_duplicated(self, get_http_api_auth):
|
||||
name = "duplicated_name"
|
||||
payload = {"name": name}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 102, res
|
||||
assert res["message"] == f"Dataset name '{name}' already exists", res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_name_case_insensitive(self, get_http_api_auth):
|
||||
name = "CaseInsensitive"
|
||||
payload = {"name": name.upper()}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
payload = {"name": name.lower()}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 102, res
|
||||
assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_avatar(self, get_http_api_auth, tmp_path):
|
||||
fn = create_image_file(tmp_path / "ragflow_test.png")
|
||||
@ -134,16 +147,10 @@ class TestDatasetCreation:
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_avatar_none(self, get_http_api_auth, tmp_path):
|
||||
payload = {"name": "test_avatar_none", "avatar": None}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["avatar"] is None, res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_avatar_exceeds_limit_length(self, get_http_api_auth):
|
||||
res = create_dataset(get_http_api_auth, {"name": "exceeds_limit_length_avatar", "avatar": "a" * 65536})
|
||||
payload = {"name": "exceeds_limit_length_avatar", "avatar": "a" * 65536}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "String should have at most 65535 characters" in res["message"], res
|
||||
|
||||
@ -158,7 +165,7 @@ class TestDatasetCreation:
|
||||
],
|
||||
ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"],
|
||||
)
|
||||
def test_invalid_avatar_prefix(self, get_http_api_auth, tmp_path, name, avatar_prefix, expected_message):
|
||||
def test_avatar_invalid_prefix(self, get_http_api_auth, tmp_path, name, avatar_prefix, expected_message):
|
||||
fn = create_image_file(tmp_path / "ragflow_test.png")
|
||||
payload = {
|
||||
"name": name,
|
||||
@ -169,11 +176,25 @@ class TestDatasetCreation:
|
||||
assert expected_message in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_description_none(self, get_http_api_auth):
|
||||
payload = {"name": "test_description_none", "description": None}
|
||||
def test_avatar_unset(self, get_http_api_auth):
|
||||
payload = {"name": "test_avatar_unset"}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["description"] is None, res
|
||||
assert res["data"]["avatar"] is None, res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_avatar_none(self, get_http_api_auth):
|
||||
payload = {"name": "test_avatar_none", "avatar": None}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["avatar"] is None, res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_description(self, get_http_api_auth):
|
||||
payload = {"name": "test_description", "description": "description"}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["description"] == "description", res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_description_exceeds_limit_length(self, get_http_api_auth):
|
||||
@ -182,6 +203,20 @@ class TestDatasetCreation:
|
||||
assert res["code"] == 101, res
|
||||
assert "String should have at most 65535 characters" in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_description_unset(self, get_http_api_auth):
|
||||
payload = {"name": "test_description_unset"}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["description"] is None, res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_description_none(self, get_http_api_auth):
|
||||
payload = {"name": "test_description_none", "description": None}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["description"] is None, res
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"name, embedding_model",
|
||||
@ -189,22 +224,14 @@ class TestDatasetCreation:
|
||||
("BAAI/bge-large-zh-v1.5@BAAI", "BAAI/bge-large-zh-v1.5@BAAI"),
|
||||
("maidalun1020/bce-embedding-base_v1@Youdao", "maidalun1020/bce-embedding-base_v1@Youdao"),
|
||||
("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"),
|
||||
("embedding_model_default", None),
|
||||
],
|
||||
ids=["builtin_baai", "builtin_youdao", "tenant_zhipu", "default"],
|
||||
ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"],
|
||||
)
|
||||
def test_valid_embedding_model(self, get_http_api_auth, name, embedding_model):
|
||||
if embedding_model is None:
|
||||
payload = {"name": name}
|
||||
else:
|
||||
payload = {"name": name, "embedding_model": embedding_model}
|
||||
|
||||
def test_embedding_model(self, get_http_api_auth, name, embedding_model):
|
||||
payload = {"name": name, "embedding_model": embedding_model}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
if embedding_model is None:
|
||||
assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res
|
||||
else:
|
||||
assert res["data"]["embedding_model"] == embedding_model, res
|
||||
assert res["data"]["embedding_model"] == embedding_model, res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
@ -217,7 +244,7 @@ class TestDatasetCreation:
|
||||
],
|
||||
ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"],
|
||||
)
|
||||
def test_invalid_embedding_model(self, get_http_api_auth, name, embedding_model):
|
||||
def test_embedding_model_invalid(self, get_http_api_auth, name, embedding_model):
|
||||
payload = {"name": name, "embedding_model": embedding_model}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
@ -247,6 +274,20 @@ class TestDatasetCreation:
|
||||
else:
|
||||
assert "Both model_name and provider must be non-empty strings" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_embedding_model_unset(self, get_http_api_auth):
|
||||
payload = {"name": "embedding_model_unset"}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_embedding_model_none(self, get_http_api_auth):
|
||||
payload = {"name": "test_embedding_model_none", "embedding_model": None}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be a valid string" in res["message"], res
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"name, permission",
|
||||
@ -255,21 +296,14 @@ class TestDatasetCreation:
|
||||
("team", "team"),
|
||||
("me_upercase", "ME"),
|
||||
("team_upercase", "TEAM"),
|
||||
("permission_default", None),
|
||||
],
|
||||
ids=["me", "team", "me_upercase", "team_upercase", "permission_default"],
|
||||
ids=["me", "team", "me_upercase", "team_upercase"],
|
||||
)
|
||||
def test_valid_permission(self, get_http_api_auth, name, permission):
|
||||
if permission is None:
|
||||
payload = {"name": name}
|
||||
else:
|
||||
payload = {"name": name, "permission": permission}
|
||||
def test_permission(self, get_http_api_auth, name, permission):
|
||||
payload = {"name": name, "permission": permission}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
if permission is None:
|
||||
assert res["data"]["permission"] == "me", res
|
||||
else:
|
||||
assert res["data"]["permission"] == permission.lower(), res
|
||||
assert res["data"]["permission"] == permission.lower(), res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
@ -279,13 +313,28 @@ class TestDatasetCreation:
|
||||
("unknown", "unknown"),
|
||||
("type_error", list()),
|
||||
],
|
||||
ids=["empty", "unknown", "type_error"],
|
||||
)
|
||||
def test_invalid_permission(self, get_http_api_auth, name, permission):
|
||||
def test_permission_invalid(self, get_http_api_auth, name, permission):
|
||||
payload = {"name": name, "permission": permission}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101
|
||||
assert "Input should be 'me' or 'team'" in res["message"]
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_permission_unset(self, get_http_api_auth):
|
||||
payload = {"name": "test_permission_unset"}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["permission"] == "me", res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_permission_none(self, get_http_api_auth):
|
||||
payload = {"name": "test_permission_none", "permission": None}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be 'me' or 'team'" in res["message"], res
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"name, chunk_method",
|
||||
@ -302,20 +351,14 @@ class TestDatasetCreation:
|
||||
("qa", "qa"),
|
||||
("table", "table"),
|
||||
("tag", "tag"),
|
||||
("chunk_method_default", None),
|
||||
],
|
||||
ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
|
||||
)
|
||||
def test_valid_chunk_method(self, get_http_api_auth, name, chunk_method):
|
||||
if chunk_method is None:
|
||||
payload = {"name": name}
|
||||
else:
|
||||
payload = {"name": name, "chunk_method": chunk_method}
|
||||
def test_chunk_method(self, get_http_api_auth, name, chunk_method):
|
||||
payload = {"name": name, "chunk_method": chunk_method}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
if chunk_method is None:
|
||||
assert res["data"]["chunk_method"] == "naive", res
|
||||
else:
|
||||
assert res["data"]["chunk_method"] == chunk_method, res
|
||||
assert res["data"]["chunk_method"] == chunk_method, res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
@ -325,19 +368,77 @@ class TestDatasetCreation:
|
||||
("unknown", "unknown"),
|
||||
("type_error", list()),
|
||||
],
|
||||
ids=["empty", "unknown", "type_error"],
|
||||
)
|
||||
def test_invalid_chunk_method(self, get_http_api_auth, name, chunk_method):
|
||||
def test_chunk_method_invalid(self, get_http_api_auth, name, chunk_method):
|
||||
payload = {"name": name, "chunk_method": chunk_method}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_method_unset(self, get_http_api_auth):
|
||||
payload = {"name": "test_chunk_method_unset"}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["chunk_method"] == "naive", res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_chunk_method_none(self, get_http_api_auth):
|
||||
payload = {"name": "chunk_method_none", "chunk_method": None}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"name, pagerank",
|
||||
[
|
||||
("pagerank_min", 0),
|
||||
("pagerank_mid", 50),
|
||||
("pagerank_max", 100),
|
||||
],
|
||||
ids=["min", "mid", "max"],
|
||||
)
|
||||
def test_pagerank(self, get_http_api_auth, name, pagerank):
|
||||
payload = {"name": name, "pagerank": pagerank}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["pagerank"] == pagerank, res
|
||||
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.parametrize(
|
||||
"name, pagerank, expected_message",
|
||||
[
|
||||
("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
|
||||
("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
|
||||
],
|
||||
ids=["min_limit", "max_limit"],
|
||||
)
|
||||
def test_pagerank_invalid(self, get_http_api_auth, name, pagerank, expected_message):
|
||||
payload = {"name": name, "pagerank": pagerank}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert expected_message in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_pagerank_unset(self, get_http_api_auth):
|
||||
payload = {"name": "pagerank_unset"}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["pagerank"] == 0, res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_pagerank_none(self, get_http_api_auth):
|
||||
payload = {"name": "pagerank_unset", "pagerank": None}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be a valid integer" in res["message"], res
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"name, parser_config",
|
||||
[
|
||||
("default_none", None),
|
||||
("default_empty", {}),
|
||||
("auto_keywords_min", {"auto_keywords": 0}),
|
||||
("auto_keywords_mid", {"auto_keywords": 16}),
|
||||
("auto_keywords_max", {"auto_keywords": 32}),
|
||||
@ -363,7 +464,7 @@ class TestDatasetCreation:
|
||||
("task_page_size_min", {"task_page_size": 1}),
|
||||
("task_page_size_None", {"task_page_size": None}),
|
||||
("pages", {"pages": [[1, 100]]}),
|
||||
("pages_none", None),
|
||||
("pages_none", {"pages": None}),
|
||||
("graphrag_true", {"graphrag": {"use_graphrag": True}}),
|
||||
("graphrag_false", {"graphrag": {"use_graphrag": False}}),
|
||||
("graphrag_entity_types", {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}),
|
||||
@ -388,8 +489,6 @@ class TestDatasetCreation:
|
||||
("raptor_random_seed_min", {"raptor": {"random_seed": 0}}),
|
||||
],
|
||||
ids=[
|
||||
"default_none",
|
||||
"default_empty",
|
||||
"auto_keywords_min",
|
||||
"auto_keywords_mid",
|
||||
"auto_keywords_max",
|
||||
@ -440,44 +539,16 @@ class TestDatasetCreation:
|
||||
"raptor_random_seed_min",
|
||||
],
|
||||
)
|
||||
def test_valid_parser_config(self, get_http_api_auth, name, parser_config):
|
||||
if parser_config is None:
|
||||
payload = {"name": name}
|
||||
else:
|
||||
payload = {"name": name, "parser_config": parser_config}
|
||||
def test_parser_config(self, get_http_api_auth, name, parser_config):
|
||||
payload = {"name": name, "parser_config": parser_config}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
if parser_config is None:
|
||||
assert res["data"]["parser_config"] == {
|
||||
"chunk_token_num": 128,
|
||||
"delimiter": r"\n",
|
||||
"html4excel": False,
|
||||
"layout_recognize": "DeepDOC",
|
||||
"raptor": {"use_raptor": False},
|
||||
}
|
||||
elif parser_config == {}:
|
||||
assert res["data"]["parser_config"] == {
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"chunk_token_num": 128,
|
||||
"delimiter": r"\n",
|
||||
"filename_embd_weight": None,
|
||||
"graphrag": None,
|
||||
"html4excel": False,
|
||||
"layout_recognize": "DeepDOC",
|
||||
"pages": None,
|
||||
"raptor": None,
|
||||
"tag_kb_ids": [],
|
||||
"task_page_size": None,
|
||||
"topn_tags": 1,
|
||||
}
|
||||
else:
|
||||
for k, v in parser_config.items():
|
||||
if isinstance(v, dict):
|
||||
for kk, vv in v.items():
|
||||
assert res["data"]["parser_config"][k][kk] == vv
|
||||
else:
|
||||
assert res["data"]["parser_config"][k] == v
|
||||
for k, v in parser_config.items():
|
||||
if isinstance(v, dict):
|
||||
for kk, vv in v.items():
|
||||
assert res["data"]["parser_config"][k][kk] == vv, res
|
||||
else:
|
||||
assert res["data"]["parser_config"][k] == v, res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
@ -595,15 +666,72 @@ class TestDatasetCreation:
|
||||
"parser_config_type_invalid",
|
||||
],
|
||||
)
|
||||
def test_invalid_parser_config(self, get_http_api_auth, name, parser_config, expected_message):
|
||||
def test_parser_config_invalid(self, get_http_api_auth, name, parser_config, expected_message):
|
||||
payload = {"name": name, "parser_config": parser_config}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert expected_message in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_parser_config_empty(self, get_http_api_auth):
|
||||
payload = {"name": "default_empty", "parser_config": {}}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["parser_config"] == {
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"chunk_token_num": 128,
|
||||
"delimiter": r"\n",
|
||||
"filename_embd_weight": None,
|
||||
"graphrag": None,
|
||||
"html4excel": False,
|
||||
"layout_recognize": "DeepDOC",
|
||||
"pages": None,
|
||||
"raptor": None,
|
||||
"tag_kb_ids": [],
|
||||
"task_page_size": None,
|
||||
"topn_tags": 1,
|
||||
}
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_parser_config_unset(self, get_http_api_auth):
|
||||
payload = {"name": "default_unset"}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["parser_config"] == {
|
||||
"chunk_token_num": 128,
|
||||
"delimiter": r"\n",
|
||||
"html4excel": False,
|
||||
"layout_recognize": "DeepDOC",
|
||||
"raptor": {"use_raptor": False},
|
||||
}, res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_dataset_10k(self, get_http_api_auth):
|
||||
for i in range(10_000):
|
||||
payload = {"name": f"dataset_{i}"}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 0, f"Failed to create dataset {i}"
|
||||
def test_parser_config_none(self, get_http_api_auth):
|
||||
payload = {"name": "default_none", "parser_config": None}
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be a valid dictionary or instance of ParserConfig" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"payload",
|
||||
[
|
||||
{"name": "id", "id": "id"},
|
||||
{"name": "tenant_id", "tenant_id": "e57c1966f99211efb41e9e45646e0111"},
|
||||
{"name": "created_by", "created_by": "created_by"},
|
||||
{"name": "create_date", "create_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
|
||||
{"name": "create_time", "create_time": 1741671443322},
|
||||
{"name": "update_date", "update_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
|
||||
{"name": "update_time", "update_time": 1741671443339},
|
||||
{"name": "document_count", "document_count": 1},
|
||||
{"name": "chunk_count", "chunk_count": 1},
|
||||
{"name": "token_num", "token_num": 1},
|
||||
{"name": "status", "status": "1"},
|
||||
{"name": "unknown_field", "unknown_field": "unknown_field"},
|
||||
],
|
||||
)
|
||||
def test_unsupported_field(self, get_http_api_auth, payload):
|
||||
res = create_dataset(get_http_api_auth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Extra inputs are not permitted" in res["message"], res
|
||||
|
@ -16,21 +16,18 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
from common import (
|
||||
DATASET_NAME_LIMIT,
|
||||
INVALID_API_TOKEN,
|
||||
list_datasets,
|
||||
update_dataset,
|
||||
)
|
||||
from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset
|
||||
from hypothesis import HealthCheck, example, given, settings
|
||||
from libs.auth import RAGFlowHttpApiAuth
|
||||
from libs.utils import encode_avatar
|
||||
from libs.utils.file_utils import create_image_file
|
||||
from libs.utils.hypothesis_utils import valid_names
|
||||
|
||||
# TODO: Missing scenario for updating embedding_model with chunk_count != 0
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"auth, expected_code, expected_message",
|
||||
[
|
||||
@ -41,111 +38,178 @@ class TestAuthorization:
|
||||
"Authentication error: API key is invalid!",
|
||||
),
|
||||
],
|
||||
ids=["empty_auth", "invalid_api_token"],
|
||||
)
|
||||
def test_invalid_auth(self, auth, expected_code, expected_message):
|
||||
def test_auth_invalid(self, auth, expected_code, expected_message):
|
||||
res = update_dataset(auth, "dataset_id")
|
||||
assert res["code"] == expected_code
|
||||
assert res["message"] == expected_message
|
||||
assert res["code"] == expected_code, res
|
||||
assert res["message"] == expected_message, res
|
||||
|
||||
|
||||
class TestRquest:
|
||||
@pytest.mark.p3
|
||||
def test_bad_content_type(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
BAD_CONTENT_TYPE = "text/xml"
|
||||
res = update_dataset(get_http_api_auth, dataset_id, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE})
|
||||
assert res["code"] == 101, res
|
||||
assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res
|
||||
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.parametrize(
|
||||
"payload, expected_message",
|
||||
[
|
||||
("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"),
|
||||
('"a"', "Invalid request payload: expected object, got str"),
|
||||
],
|
||||
ids=["malformed_json_syntax", "invalid_request_payload_type"],
|
||||
)
|
||||
def test_payload_bad(self, get_http_api_auth, add_dataset_func, payload, expected_message):
|
||||
dataset_id = add_dataset_func
|
||||
res = update_dataset(get_http_api_auth, dataset_id, data=payload)
|
||||
assert res["code"] == 101, res
|
||||
assert res["message"] == expected_message, res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_payload_empty(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
res = update_dataset(get_http_api_auth, dataset_id, {})
|
||||
assert res["code"] == 101, res
|
||||
assert res["message"] == "No properties were modified", res
|
||||
|
||||
|
||||
class TestCapability:
|
||||
@pytest.mark.p3
|
||||
def test_update_dateset_concurrent(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)]
|
||||
responses = [f.result() for f in futures]
|
||||
assert all(r["code"] == 0 for r in responses), responses
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
class TestDatasetUpdate:
|
||||
@pytest.mark.parametrize(
|
||||
"name, expected_code, expected_message",
|
||||
[
|
||||
("valid_name", 0, ""),
|
||||
(
|
||||
"a" * (DATASET_NAME_LIMIT + 1),
|
||||
102,
|
||||
"Dataset name should not be longer than 128 characters.",
|
||||
),
|
||||
(0, 100, """AttributeError("\'int\' object has no attribute \'strip\'")"""),
|
||||
(
|
||||
None,
|
||||
100,
|
||||
"""AttributeError("\'NoneType\' object has no attribute \'strip\'")""",
|
||||
),
|
||||
pytest.param("", 102, "", marks=pytest.mark.skip(reason="issue/5915")),
|
||||
("dataset_1", 102, "Duplicated dataset name in updating dataset."),
|
||||
("DATASET_1", 102, "Duplicated dataset name in updating dataset."),
|
||||
],
|
||||
)
|
||||
def test_name(self, get_http_api_auth, add_datasets_func, name, expected_code, expected_message):
|
||||
dataset_ids = add_datasets_func
|
||||
res = update_dataset(get_http_api_auth, dataset_ids[0], {"name": name})
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
res = list_datasets(get_http_api_auth, {"id": dataset_ids[0]})
|
||||
assert res["data"][0]["name"] == name
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
@pytest.mark.p3
|
||||
def test_dataset_id_not_uuid(self, get_http_api_auth):
|
||||
payload = {"name": "dataset_id_not_uuid"}
|
||||
res = update_dataset(get_http_api_auth, "not_uuid", payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be a valid UUID" in res["message"], res
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"embedding_model, expected_code, expected_message",
|
||||
[
|
||||
("BAAI/bge-large-zh-v1.5", 0, ""),
|
||||
("maidalun1020/bce-embedding-base_v1", 0, ""),
|
||||
(
|
||||
"other_embedding_model",
|
||||
102,
|
||||
"`embedding_model` other_embedding_model doesn't exist",
|
||||
),
|
||||
(None, 102, "`embedding_model` can't be empty"),
|
||||
],
|
||||
)
|
||||
def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model, expected_code, expected_message):
|
||||
@pytest.mark.p3
|
||||
def test_dataset_id_wrong_uuid(self, get_http_api_auth):
|
||||
payload = {"name": "wrong_uuid"}
|
||||
res = update_dataset(get_http_api_auth, "d94a8dc02c9711f0930f7fbc369eab6d", payload)
|
||||
assert res["code"] == 102, res
|
||||
assert "lacks permission for dataset" in res["message"], res
|
||||
|
||||
@pytest.mark.p1
|
||||
@given(name=valid_names())
|
||||
@example("a" * 128)
|
||||
@settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture])
|
||||
def test_name(self, get_http_api_auth, add_dataset_func, name):
|
||||
dataset_id = add_dataset_func
|
||||
res = update_dataset(get_http_api_auth, dataset_id, {"embedding_model": embedding_model})
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
res = list_datasets(get_http_api_auth, {"id": dataset_id})
|
||||
assert res["data"][0]["embedding_model"] == embedding_model
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
payload = {"name": name}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = list_datasets(get_http_api_auth)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["name"] == name, res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"chunk_method, expected_code, expected_message",
|
||||
"name, expected_message",
|
||||
[
|
||||
("naive", 0, ""),
|
||||
("manual", 0, ""),
|
||||
("qa", 0, ""),
|
||||
("table", 0, ""),
|
||||
("paper", 0, ""),
|
||||
("book", 0, ""),
|
||||
("laws", 0, ""),
|
||||
("presentation", 0, ""),
|
||||
("picture", 0, ""),
|
||||
("one", 0, ""),
|
||||
("email", 0, ""),
|
||||
("tag", 0, ""),
|
||||
("", 0, ""),
|
||||
(
|
||||
"other_chunk_method",
|
||||
102,
|
||||
"'other_chunk_method' is not in ['naive', 'manual', 'qa', 'table', 'paper', 'book', 'laws', 'presentation', 'picture', 'one', 'email', 'tag']",
|
||||
),
|
||||
("", "String should have at least 1 character"),
|
||||
(" ", "String should have at least 1 character"),
|
||||
("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"),
|
||||
(0, "Input should be a valid string"),
|
||||
(None, "Input should be a valid string"),
|
||||
],
|
||||
ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"],
|
||||
)
|
||||
def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method, expected_code, expected_message):
|
||||
def test_name_invalid(self, get_http_api_auth, add_dataset_func, name, expected_message):
|
||||
dataset_id = add_dataset_func
|
||||
res = update_dataset(get_http_api_auth, dataset_id, {"chunk_method": chunk_method})
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
res = list_datasets(get_http_api_auth, {"id": dataset_id})
|
||||
if chunk_method != "":
|
||||
assert res["data"][0]["chunk_method"] == chunk_method
|
||||
else:
|
||||
assert res["data"][0]["chunk_method"] == "naive"
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
payload = {"name": name}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert expected_message in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_name_duplicated(self, get_http_api_auth, add_datasets_func):
|
||||
dataset_ids = add_datasets_func[0]
|
||||
name = "dataset_1"
|
||||
payload = {"name": name}
|
||||
res = update_dataset(get_http_api_auth, dataset_ids, payload)
|
||||
assert res["code"] == 102, res
|
||||
assert res["message"] == f"Dataset name '{name}' already exists", res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_name_case_insensitive(self, get_http_api_auth, add_datasets_func):
|
||||
dataset_id = add_datasets_func[0]
|
||||
name = "DATASET_1"
|
||||
payload = {"name": name}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 102, res
|
||||
assert res["message"] == f"Dataset name '{name}' already exists", res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_avatar(self, get_http_api_auth, add_dataset_func, tmp_path):
|
||||
dataset_id = add_dataset_func
|
||||
fn = create_image_file(tmp_path / "ragflow_test.png")
|
||||
payload = {"avatar": encode_avatar(fn)}
|
||||
payload = {
|
||||
"avatar": f"data:image/png;base64,{encode_avatar(fn)}",
|
||||
}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 0
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = list_datasets(get_http_api_auth)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["avatar"] == f"data:image/png;base64,{encode_avatar(fn)}", res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_avatar_exceeds_limit_length(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"avatar": "a" * 65536}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "String should have at most 65535 characters" in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.parametrize(
|
||||
"name, avatar_prefix, expected_message",
|
||||
[
|
||||
("empty_prefix", "", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"),
|
||||
("missing_comma", "data:image/png;base64", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"),
|
||||
("unsupported_mine_type", "invalid_mine_prefix:image/png;base64,", "Invalid MIME prefix format. Must start with 'data:'"),
|
||||
("invalid_mine_type", "data:unsupported_mine_type;base64,", "Unsupported MIME type. Allowed: ['image/jpeg', 'image/png']"),
|
||||
],
|
||||
ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"],
|
||||
)
|
||||
def test_avatar_invalid_prefix(self, get_http_api_auth, add_dataset_func, tmp_path, name, avatar_prefix, expected_message):
|
||||
dataset_id = add_dataset_func
|
||||
fn = create_image_file(tmp_path / "ragflow_test.png")
|
||||
payload = {
|
||||
"name": name,
|
||||
"avatar": f"{avatar_prefix}{encode_avatar(fn)}",
|
||||
}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert expected_message in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_avatar_none(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"avatar": None}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = list_datasets(get_http_api_auth)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["avatar"] is None, res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_description(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"description": "description"}
|
||||
@ -153,95 +217,533 @@ class TestDatasetUpdate:
|
||||
assert res["code"] == 0
|
||||
|
||||
res = list_datasets(get_http_api_auth, {"id": dataset_id})
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["description"] == "description"
|
||||
|
||||
def test_pagerank(self, get_http_api_auth, add_dataset_func):
|
||||
@pytest.mark.p2
|
||||
def test_description_exceeds_limit_length(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"pagerank": 1}
|
||||
payload = {"description": "a" * 65536}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 0
|
||||
assert res["code"] == 101, res
|
||||
assert "String should have at most 65535 characters" in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_description_none(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"description": None}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = list_datasets(get_http_api_auth, {"id": dataset_id})
|
||||
assert res["data"][0]["pagerank"] == 1
|
||||
|
||||
def test_similarity_threshold(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"similarity_threshold": 1}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 0
|
||||
|
||||
res = list_datasets(get_http_api_auth, {"id": dataset_id})
|
||||
assert res["data"][0]["similarity_threshold"] == 1
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["description"] is None
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"permission, expected_code",
|
||||
"embedding_model",
|
||||
[
|
||||
("me", 0),
|
||||
("team", 0),
|
||||
("", 0),
|
||||
("ME", 102),
|
||||
("TEAM", 102),
|
||||
("other_permission", 102),
|
||||
"BAAI/bge-large-zh-v1.5@BAAI",
|
||||
"maidalun1020/bce-embedding-base_v1@Youdao",
|
||||
"embedding-3@ZHIPU-AI",
|
||||
],
|
||||
ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"],
|
||||
)
|
||||
def test_permission(self, get_http_api_auth, add_dataset_func, permission, expected_code):
|
||||
def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"embedding_model": embedding_model}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = list_datasets(get_http_api_auth)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["embedding_model"] == embedding_model, res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"name, embedding_model",
|
||||
[
|
||||
("unknown_llm_name", "unknown@ZHIPU-AI"),
|
||||
("unknown_llm_factory", "embedding-3@unknown"),
|
||||
("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"),
|
||||
("tenant_no_auth", "text-embedding-3-small@OpenAI"),
|
||||
],
|
||||
ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"],
|
||||
)
|
||||
def test_embedding_model_invalid(self, get_http_api_auth, add_dataset_func, name, embedding_model):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"name": name, "embedding_model": embedding_model}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
if "tenant_no_auth" in name:
|
||||
assert res["message"] == f"Unauthorized model: <{embedding_model}>", res
|
||||
else:
|
||||
assert res["message"] == f"Unsupported model: <{embedding_model}>", res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"name, embedding_model",
|
||||
[
|
||||
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
|
||||
("missing_model_name", "@BAAI"),
|
||||
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
|
||||
("whitespace_only_model_name", " @BAAI"),
|
||||
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
|
||||
],
|
||||
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
|
||||
)
|
||||
def test_embedding_model_format(self, get_http_api_auth, add_dataset_func, name, embedding_model):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"name": name, "embedding_model": embedding_model}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
if name == "missing_at":
|
||||
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
|
||||
else:
|
||||
assert "Both model_name and provider must be non-empty strings" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_embedding_model_none(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"embedding_model": None}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be a valid string" in res["message"], res
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"name, permission",
|
||||
[
|
||||
("me", "me"),
|
||||
("team", "team"),
|
||||
("me_upercase", "ME"),
|
||||
("team_upercase", "TEAM"),
|
||||
],
|
||||
ids=["me", "team", "me_upercase", "team_upercase"],
|
||||
)
|
||||
def test_permission(self, get_http_api_auth, add_dataset_func, name, permission):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"name": name, "permission": permission}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = list_datasets(get_http_api_auth)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["permission"] == permission.lower(), res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"permission",
|
||||
[
|
||||
"",
|
||||
"unknown",
|
||||
list(),
|
||||
],
|
||||
ids=["empty", "unknown", "type_error"],
|
||||
)
|
||||
def test_permission_invalid(self, get_http_api_auth, add_dataset_func, permission):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"permission": permission}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == expected_code
|
||||
assert res["code"] == 101
|
||||
assert "Input should be 'me' or 'team'" in res["message"]
|
||||
|
||||
res = list_datasets(get_http_api_auth, {"id": dataset_id})
|
||||
if expected_code == 0 and permission != "":
|
||||
assert res["data"][0]["permission"] == permission
|
||||
if permission == "":
|
||||
assert res["data"][0]["permission"] == "me"
|
||||
|
||||
def test_vector_similarity_weight(self, get_http_api_auth, add_dataset_func):
|
||||
@pytest.mark.p3
|
||||
def test_permission_none(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"vector_similarity_weight": 1}
|
||||
payload = {"name": "test_permission_none", "permission": None}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be 'me' or 'team'" in res["message"], res
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"chunk_method",
|
||||
[
|
||||
"naive",
|
||||
"book",
|
||||
"email",
|
||||
"laws",
|
||||
"manual",
|
||||
"one",
|
||||
"paper",
|
||||
"picture",
|
||||
"presentation",
|
||||
"qa",
|
||||
"table",
|
||||
"tag",
|
||||
],
|
||||
ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
|
||||
)
|
||||
def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"chunk_method": chunk_method}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = list_datasets(get_http_api_auth)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["chunk_method"] == chunk_method, res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"chunk_method",
|
||||
[
|
||||
"",
|
||||
"unknown",
|
||||
list(),
|
||||
],
|
||||
ids=["empty", "unknown", "type_error"],
|
||||
)
|
||||
def test_chunk_method_invalid(self, get_http_api_auth, add_dataset_func, chunk_method):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"chunk_method": chunk_method}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_chunk_method_none(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"chunk_method": None}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
|
||||
def test_pagerank(self, get_http_api_auth, add_dataset_func, pagerank):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"pagerank": pagerank}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 0
|
||||
|
||||
res = list_datasets(get_http_api_auth, {"id": dataset_id})
|
||||
assert res["data"][0]["vector_similarity_weight"] == 1
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["pagerank"] == pagerank
|
||||
|
||||
def test_invalid_dataset_id(self, get_http_api_auth):
|
||||
res = update_dataset(get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"})
|
||||
assert res["code"] == 102
|
||||
assert res["message"] == "You don't own the dataset"
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"pagerank, expected_message",
|
||||
[
|
||||
(-1, "Input should be greater than or equal to 0"),
|
||||
(101, "Input should be less than or equal to 100"),
|
||||
],
|
||||
ids=["min_limit", "max_limit"],
|
||||
)
|
||||
def test_pagerank_invalid(self, get_http_api_auth, add_dataset_func, pagerank, expected_message):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"pagerank": pagerank}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert expected_message in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_pagerank_none(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"pagerank": None}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be a valid integer" in res["message"], res
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"parser_config",
|
||||
[
|
||||
{"auto_keywords": 0},
|
||||
{"auto_keywords": 16},
|
||||
{"auto_keywords": 32},
|
||||
{"auto_questions": 0},
|
||||
{"auto_questions": 5},
|
||||
{"auto_questions": 10},
|
||||
{"chunk_token_num": 1},
|
||||
{"chunk_token_num": 1024},
|
||||
{"chunk_token_num": 2048},
|
||||
{"delimiter": "\n"},
|
||||
{"delimiter": " "},
|
||||
{"html4excel": True},
|
||||
{"html4excel": False},
|
||||
{"layout_recognize": "DeepDOC"},
|
||||
{"layout_recognize": "Plain Text"},
|
||||
{"tag_kb_ids": ["1", "2"]},
|
||||
{"topn_tags": 1},
|
||||
{"topn_tags": 5},
|
||||
{"topn_tags": 10},
|
||||
{"filename_embd_weight": 0.1},
|
||||
{"filename_embd_weight": 0.5},
|
||||
{"filename_embd_weight": 1.0},
|
||||
{"task_page_size": 1},
|
||||
{"task_page_size": None},
|
||||
{"pages": [[1, 100]]},
|
||||
{"pages": None},
|
||||
{"graphrag": {"use_graphrag": True}},
|
||||
{"graphrag": {"use_graphrag": False}},
|
||||
{"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}},
|
||||
{"graphrag": {"method": "general"}},
|
||||
{"graphrag": {"method": "light"}},
|
||||
{"graphrag": {"community": True}},
|
||||
{"graphrag": {"community": False}},
|
||||
{"graphrag": {"resolution": True}},
|
||||
{"graphrag": {"resolution": False}},
|
||||
{"raptor": {"use_raptor": True}},
|
||||
{"raptor": {"use_raptor": False}},
|
||||
{"raptor": {"prompt": "Who are you?"}},
|
||||
{"raptor": {"max_token": 1}},
|
||||
{"raptor": {"max_token": 1024}},
|
||||
{"raptor": {"max_token": 2048}},
|
||||
{"raptor": {"threshold": 0.0}},
|
||||
{"raptor": {"threshold": 0.5}},
|
||||
{"raptor": {"threshold": 1.0}},
|
||||
{"raptor": {"max_cluster": 1}},
|
||||
{"raptor": {"max_cluster": 512}},
|
||||
{"raptor": {"max_cluster": 1024}},
|
||||
{"raptor": {"random_seed": 0}},
|
||||
],
|
||||
ids=[
|
||||
"auto_keywords_min",
|
||||
"auto_keywords_mid",
|
||||
"auto_keywords_max",
|
||||
"auto_questions_min",
|
||||
"auto_questions_mid",
|
||||
"auto_questions_max",
|
||||
"chunk_token_num_min",
|
||||
"chunk_token_num_mid",
|
||||
"chunk_token_num_max",
|
||||
"delimiter",
|
||||
"delimiter_space",
|
||||
"html4excel_true",
|
||||
"html4excel_false",
|
||||
"layout_recognize_DeepDOC",
|
||||
"layout_recognize_navie",
|
||||
"tag_kb_ids",
|
||||
"topn_tags_min",
|
||||
"topn_tags_mid",
|
||||
"topn_tags_max",
|
||||
"filename_embd_weight_min",
|
||||
"filename_embd_weight_mid",
|
||||
"filename_embd_weight_max",
|
||||
"task_page_size_min",
|
||||
"task_page_size_None",
|
||||
"pages",
|
||||
"pages_none",
|
||||
"graphrag_true",
|
||||
"graphrag_false",
|
||||
"graphrag_entity_types",
|
||||
"graphrag_method_general",
|
||||
"graphrag_method_light",
|
||||
"graphrag_community_true",
|
||||
"graphrag_community_false",
|
||||
"graphrag_resolution_true",
|
||||
"graphrag_resolution_false",
|
||||
"raptor_true",
|
||||
"raptor_false",
|
||||
"raptor_prompt",
|
||||
"raptor_max_token_min",
|
||||
"raptor_max_token_mid",
|
||||
"raptor_max_token_max",
|
||||
"raptor_threshold_min",
|
||||
"raptor_threshold_mid",
|
||||
"raptor_threshold_max",
|
||||
"raptor_max_cluster_min",
|
||||
"raptor_max_cluster_mid",
|
||||
"raptor_max_cluster_max",
|
||||
"raptor_random_seed_min",
|
||||
],
|
||||
)
|
||||
def test_parser_config(self, get_http_api_auth, add_dataset_func, parser_config):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"parser_config": parser_config}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = list_datasets(get_http_api_auth)
|
||||
assert res["code"] == 0, res
|
||||
for k, v in parser_config.items():
|
||||
if isinstance(v, dict):
|
||||
for kk, vv in v.items():
|
||||
assert res["data"][0]["parser_config"][k][kk] == vv, res
|
||||
else:
|
||||
assert res["data"][0]["parser_config"][k] == v, res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"parser_config, expected_message",
|
||||
[
|
||||
({"auto_keywords": -1}, "Input should be greater than or equal to 0"),
|
||||
({"auto_keywords": 33}, "Input should be less than or equal to 32"),
|
||||
({"auto_keywords": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
|
||||
({"auto_keywords": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
|
||||
({"auto_questions": -1}, "Input should be greater than or equal to 0"),
|
||||
({"auto_questions": 11}, "Input should be less than or equal to 10"),
|
||||
({"auto_questions": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
|
||||
({"auto_questions": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
|
||||
({"chunk_token_num": 0}, "Input should be greater than or equal to 1"),
|
||||
({"chunk_token_num": 2049}, "Input should be less than or equal to 2048"),
|
||||
({"chunk_token_num": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
|
||||
({"chunk_token_num": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
|
||||
({"delimiter": ""}, "String should have at least 1 character"),
|
||||
({"html4excel": "string"}, "Input should be a valid boolean, unable to interpret input"),
|
||||
({"tag_kb_ids": "1,2"}, "Input should be a valid list"),
|
||||
({"tag_kb_ids": [1, 2]}, "Input should be a valid string"),
|
||||
({"topn_tags": 0}, "Input should be greater than or equal to 1"),
|
||||
({"topn_tags": 11}, "Input should be less than or equal to 10"),
|
||||
({"topn_tags": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
|
||||
({"topn_tags": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
|
||||
({"filename_embd_weight": -1}, "Input should be greater than or equal to 0"),
|
||||
({"filename_embd_weight": 1.1}, "Input should be less than or equal to 1"),
|
||||
({"filename_embd_weight": "string"}, "Input should be a valid number, unable to parse string as a number"),
|
||||
({"task_page_size": 0}, "Input should be greater than or equal to 1"),
|
||||
({"task_page_size": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
|
||||
({"task_page_size": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
|
||||
({"pages": "1,2"}, "Input should be a valid list"),
|
||||
({"pages": ["1,2"]}, "Input should be a valid list"),
|
||||
({"pages": [["string1", "string2"]]}, "Input should be a valid integer, unable to parse string as an integer"),
|
||||
({"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean, unable to interpret input"),
|
||||
({"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"),
|
||||
({"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"),
|
||||
({"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"),
|
||||
({"graphrag": {"method": None}}, "Input should be 'light' or 'general'"),
|
||||
({"graphrag": {"community": "string"}}, "Input should be a valid boolean, unable to interpret input"),
|
||||
({"graphrag": {"resolution": "string"}}, "Input should be a valid boolean, unable to interpret input"),
|
||||
({"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean, unable to interpret input"),
|
||||
({"raptor": {"prompt": ""}}, "String should have at least 1 character"),
|
||||
({"raptor": {"prompt": " "}}, "String should have at least 1 character"),
|
||||
({"raptor": {"max_token": 0}}, "Input should be greater than or equal to 1"),
|
||||
({"raptor": {"max_token": 2049}}, "Input should be less than or equal to 2048"),
|
||||
({"raptor": {"max_token": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
|
||||
({"raptor": {"max_token": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
|
||||
({"raptor": {"threshold": -0.1}}, "Input should be greater than or equal to 0"),
|
||||
({"raptor": {"threshold": 1.1}}, "Input should be less than or equal to 1"),
|
||||
({"raptor": {"threshold": "string"}}, "Input should be a valid number, unable to parse string as a number"),
|
||||
({"raptor": {"max_cluster": 0}}, "Input should be greater than or equal to 1"),
|
||||
({"raptor": {"max_cluster": 1025}}, "Input should be less than or equal to 1024"),
|
||||
({"raptor": {"max_cluster": 3.14}}, "Input should be a valid integer, got a number with a fractional par"),
|
||||
({"raptor": {"max_cluster": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
|
||||
({"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"),
|
||||
({"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
|
||||
({"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
|
||||
({"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"),
|
||||
],
|
||||
ids=[
|
||||
"auto_keywords_min_limit",
|
||||
"auto_keywords_max_limit",
|
||||
"auto_keywords_float_not_allowed",
|
||||
"auto_keywords_type_invalid",
|
||||
"auto_questions_min_limit",
|
||||
"auto_questions_max_limit",
|
||||
"auto_questions_float_not_allowed",
|
||||
"auto_questions_type_invalid",
|
||||
"chunk_token_num_min_limit",
|
||||
"chunk_token_num_max_limit",
|
||||
"chunk_token_num_float_not_allowed",
|
||||
"chunk_token_num_type_invalid",
|
||||
"delimiter_empty",
|
||||
"html4excel_type_invalid",
|
||||
"tag_kb_ids_not_list",
|
||||
"tag_kb_ids_int_in_list",
|
||||
"topn_tags_min_limit",
|
||||
"topn_tags_max_limit",
|
||||
"topn_tags_float_not_allowed",
|
||||
"topn_tags_type_invalid",
|
||||
"filename_embd_weight_min_limit",
|
||||
"filename_embd_weight_max_limit",
|
||||
"filename_embd_weight_type_invalid",
|
||||
"task_page_size_min_limit",
|
||||
"task_page_size_float_not_allowed",
|
||||
"task_page_size_type_invalid",
|
||||
"pages_not_list",
|
||||
"pages_not_list_in_list",
|
||||
"pages_not_int_list",
|
||||
"graphrag_type_invalid",
|
||||
"graphrag_entity_types_not_list",
|
||||
"graphrag_entity_types_not_str_in_list",
|
||||
"graphrag_method_unknown",
|
||||
"graphrag_method_none",
|
||||
"graphrag_community_type_invalid",
|
||||
"graphrag_resolution_type_invalid",
|
||||
"raptor_type_invalid",
|
||||
"raptor_prompt_empty",
|
||||
"raptor_prompt_space",
|
||||
"raptor_max_token_min_limit",
|
||||
"raptor_max_token_max_limit",
|
||||
"raptor_max_token_float_not_allowed",
|
||||
"raptor_max_token_type_invalid",
|
||||
"raptor_threshold_min_limit",
|
||||
"raptor_threshold_max_limit",
|
||||
"raptor_threshold_type_invalid",
|
||||
"raptor_max_cluster_min_limit",
|
||||
"raptor_max_cluster_max_limit",
|
||||
"raptor_max_cluster_float_not_allowed",
|
||||
"raptor_max_cluster_type_invalid",
|
||||
"raptor_random_seed_min_limit",
|
||||
"raptor_random_seed_float_not_allowed",
|
||||
"raptor_random_seed_type_invalid",
|
||||
"parser_config_type_invalid",
|
||||
],
|
||||
)
|
||||
def test_parser_config_invalid(self, get_http_api_auth, add_dataset_func, parser_config, expected_message):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"parser_config": parser_config}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert expected_message in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_parser_config_empty(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"parser_config": {}}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = list_datasets(get_http_api_auth)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["parser_config"] == {}
|
||||
|
||||
# @pytest.mark.p2
|
||||
# def test_parser_config_unset(self, get_http_api_auth, add_dataset_func):
|
||||
# dataset_id = add_dataset_func
|
||||
# payload = {"name": "default_unset"}
|
||||
# res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
# assert res["code"] == 0, res
|
||||
|
||||
# res = list_datasets(get_http_api_auth)
|
||||
# assert res["code"] == 0, res
|
||||
# assert res["data"][0]["parser_config"] == {
|
||||
# "chunk_token_num": 128,
|
||||
# "delimiter": r"\n",
|
||||
# "html4excel": False,
|
||||
# "layout_recognize": "DeepDOC",
|
||||
# "raptor": {"use_raptor": False},
|
||||
# }, res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_parser_config_none(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"parser_config": None}
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be a valid dictionary or instance of ParserConfig" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"payload",
|
||||
[
|
||||
{"chunk_count": 1},
|
||||
{"id": "id"},
|
||||
{"tenant_id": "e57c1966f99211efb41e9e45646e0111"},
|
||||
{"created_by": "created_by"},
|
||||
{"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
|
||||
{"create_time": 1741671443322},
|
||||
{"created_by": "aa"},
|
||||
{"document_count": 1},
|
||||
{"id": "id"},
|
||||
{"status": "1"},
|
||||
{"tenant_id": "e57c1966f99211efb41e9e45646e0111"},
|
||||
{"token_num": 1},
|
||||
{"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
|
||||
{"update_time": 1741671443339},
|
||||
{"document_count": 1},
|
||||
{"chunk_count": 1},
|
||||
{"token_num": 1},
|
||||
{"status": "1"},
|
||||
{"unknown_field": "unknown_field"},
|
||||
],
|
||||
)
|
||||
def test_modify_read_only_field(self, get_http_api_auth, add_dataset_func, payload):
|
||||
def test_unsupported_field(self, get_http_api_auth, add_dataset_func, payload):
|
||||
dataset_id = add_dataset_func
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101
|
||||
assert "is readonly" in res["message"]
|
||||
|
||||
def test_modify_unknown_field(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
res = update_dataset(get_http_api_auth, dataset_id, {"unknown_field": 0})
|
||||
assert res["code"] == 100
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_concurrent_update(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)]
|
||||
responses = [f.result() for f in futures]
|
||||
assert all(r["code"] == 0 for r in responses)
|
||||
assert res["code"] == 101, res
|
||||
assert "Extra inputs are not permitted" in res["message"], res
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets, stop_parse_documnets
|
||||
@ -173,6 +174,7 @@ def test_stop_parse_100_files(get_http_api_auth, add_dataset_func, tmp_path):
|
||||
dataset_id = add_dataset_func
|
||||
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
|
||||
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
|
||||
sleep(1)
|
||||
res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
|
||||
assert res["code"] == 0
|
||||
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)
|
||||
|
Loading…
x
Reference in New Issue
Block a user