From 35e36cb9457cbb75ca939dae0ee19297e3a63721 Mon Sep 17 00:00:00 2001 From: liu an Date: Fri, 9 May 2025 19:17:08 +0800 Subject: [PATCH] Refa: HTTP API update dataset / test cases / docs (#7564) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? This PR introduces Pydantic-based validation for the update dataset HTTP API, improving code clarity and robustness. Key changes include: 1. Pydantic Validation 2. ​​Error Handling 3. Test Updates 4. Documentation Updates 5. fix bug: #5915 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Documentation Update - [x] Refactoring --- api/apps/sdk/dataset.py | 172 ++-- api/utils/api_utils.py | 104 ++- api/utils/validation_utils.py | 98 ++- docs/references/http_api_reference.md | 79 +- docs/references/python_api_reference.md | 42 +- sdk/python/ragflow_sdk/ragflow.py | 28 +- .../test/libs/utils/hypothesis_utils.py | 28 + sdk/python/test/test_http_api/common.py | 12 +- .../test_dataset_mangement/conftest.py | 10 + .../test_create_dataset.py | 454 ++++++---- .../test_update_dataset.py | 806 ++++++++++++++---- .../test_stop_parse_documents.py | 2 + 12 files changed, 1283 insertions(+), 552 deletions(-) create mode 100644 sdk/python/test/libs/utils/hypothesis_utils.py diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index b52f584d3..7dec0b9bc 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -27,22 +27,19 @@ from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.llm_service import LLMService, TenantLLMService from api.db.services.user_service import TenantService from api.utils import get_uuid from api.utils.api_utils import ( check_duplicate_ids, - dataset_readonly_fields, + deep_merge, get_error_argument_result, get_error_data_result, get_parser_config, get_result, token_required, - valid, - valid_parser_config, verify_embedding_availability, ) -from api.utils.validation_utils import CreateDatasetReq, validate_and_parse_json_request +from api.utils.validation_utils import CreateDatasetReq, UpdateDatasetReq, validate_and_parse_json_request @manager.route("/datasets", methods=["POST"]) # noqa: F821 @@ -117,8 +114,8 @@ def create(tenant_id): return get_error_argument_result(err) try: - if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): - return get_error_argument_result(message=f"Dataset name '{req['name']}' already exists") + if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): + return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") except OperationalError as e: logging.exception(e) return get_error_data_result(message="Database operation failed") @@ -145,7 +142,7 @@ def create(tenant_id): try: if not KnowledgebaseService.save(**req): - return get_error_data_result(message="Database operation failed") + return get_error_data_result(message="Create dataset error.(Database error)") except OperationalError as e: logging.exception(e) return get_error_data_result(message="Database operation failed") @@ -205,6 +202,7 @@ def delete(tenant_id): schema: type: object """ + errors = [] success_count = 0 req = request.json @@ -291,16 +289,28 @@ def update(tenant_id, dataset_id): name: type: string description: New name of the dataset. + avatar: + type: string + description: Updated base64 encoding of the avatar. + description: + type: string + description: Updated description of the dataset. + embedding_model: + type: string + description: Updated embedding model Name. permission: type: string enum: ['me', 'team'] - description: Updated permission. + description: Updated dataset permission. chunk_method: type: string - enum: ["naive", "manual", "qa", "table", "paper", "book", "laws", - "presentation", "picture", "one", "email", "tag" + enum: ["naive", "book", "email", "laws", "manual", "one", "paper", + "picture", "presentation", "qa", "table", "tag" ] description: Updated chunking method. + pagerank: + type: integer + description: Updated page rank. parser_config: type: object description: Updated parser configuration. @@ -310,98 +320,56 @@ def update(tenant_id, dataset_id): schema: type: object """ - if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): - return get_error_data_result(message="You don't own the dataset") - req = request.json - for k in req.keys(): - if dataset_readonly_fields(k): - return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"'{k}' is readonly.") - e, t = TenantService.get_by_id(tenant_id) - invalid_keys = {"id", "embd_id", "chunk_num", "doc_num", "parser_id", "create_date", "create_time", "created_by", "status", "token_num", "update_date", "update_time"} - if any(key in req for key in invalid_keys): - return get_error_data_result(message="The input parameters are invalid.") - permission = req.get("permission") - chunk_method = req.get("chunk_method") - parser_config = req.get("parser_config") - valid_parser_config(parser_config) - valid_permission = ["me", "team"] - valid_chunk_method = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "email", "tag"] - check_validation = valid( - permission, - valid_permission, - chunk_method, - valid_chunk_method, - ) - if check_validation: - return check_validation - if "tenant_id" in req: - if req["tenant_id"] != tenant_id: - return get_error_data_result(message="Can't change `tenant_id`.") - e, kb = KnowledgebaseService.get_by_id(dataset_id) - if "parser_config" in req: - temp_dict = kb.parser_config - temp_dict.update(req["parser_config"]) - req["parser_config"] = temp_dict - if "chunk_count" in req: - if req["chunk_count"] != kb.chunk_num: - return get_error_data_result(message="Can't change `chunk_count`.") - req.pop("chunk_count") - if "document_count" in req: - if req["document_count"] != kb.doc_num: - return get_error_data_result(message="Can't change `document_count`.") - req.pop("document_count") - if req.get("chunk_method"): - if kb.chunk_num != 0 and req["chunk_method"] != kb.parser_id: - return get_error_data_result(message="If `chunk_count` is not 0, `chunk_method` is not changeable.") - req["parser_id"] = req.pop("chunk_method") - if req["parser_id"] != kb.parser_id: - if not req.get("parser_config"): - req["parser_config"] = get_parser_config(chunk_method, parser_config) - if "embedding_model" in req: - if kb.chunk_num != 0 and req["embedding_model"] != kb.embd_id: - return get_error_data_result(message="If `chunk_count` is not 0, `embedding_model` is not changeable.") - if not req.get("embedding_model"): - return get_error_data_result("`embedding_model` can't be empty") - valid_embedding_models = [ - "BAAI/bge-large-zh-v1.5", - "BAAI/bge-base-en-v1.5", - "BAAI/bge-large-en-v1.5", - "BAAI/bge-small-en-v1.5", - "BAAI/bge-small-zh-v1.5", - "jinaai/jina-embeddings-v2-base-en", - "jinaai/jina-embeddings-v2-small-en", - "nomic-ai/nomic-embed-text-v1.5", - "sentence-transformers/all-MiniLM-L6-v2", - "text-embedding-v2", - "text-embedding-v3", - "maidalun1020/bce-embedding-base_v1", - ] - embd_model = LLMService.query(llm_name=req["embedding_model"], model_type="embedding") - if embd_model: - if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query( - tenant_id=tenant_id, - model_type="embedding", - llm_name=req.get("embedding_model"), - ): - return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") - if not embd_model: - embd_model = TenantLLMService.query(tenant_id=tenant_id, model_type="embedding", llm_name=req.get("embedding_model")) + # Field name transformations during model dump: + # | Original | Dump Output | + # |----------------|-------------| + # | embedding_model| embd_id | + # | chunk_method | parser_id | + extras = {"dataset_id": dataset_id} + req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True) + if err is not None: + return get_error_argument_result(err) + + if not req: + return get_error_argument_result(message="No properties were modified") + + try: + kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) + if kb is None: + return get_error_data_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + + if req.get("parser_config"): + req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"]) + + if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id and req.get("parser_config") is None: + req["parser_config"] = get_parser_config(chunk_method, None) + + if "name" in req and req["name"].lower() != kb.name.lower(): + try: + exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value) + if exists: + return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + + if "embd_id" in req: + if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: + return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") + ok, err = verify_embedding_availability(req["embd_id"], tenant_id) + if not ok: + return err + + try: + if not KnowledgebaseService.update_by_id(kb.id, req): + return get_error_data_result(message="Update dataset error.(Database error)") + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") - if not embd_model: - return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") - req["embd_id"] = req.pop("embedding_model") - if "name" in req: - req["name"] = req["name"].strip() - if len(req["name"]) >= 128: - return get_error_data_result(message="Dataset name should not be longer than 128 characters.") - if req["name"].lower() != kb.name.lower() and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0: - return get_error_data_result(message="Duplicated dataset name in updating dataset.") - flds = list(req.keys()) - for f in flds: - if req[f] == "" and f in ["permission", "parser_id", "chunk_method"]: - del req[f] - if not KnowledgebaseService.update_by_id(kb.id, req): - return get_error_data_result(message="Update dataset error.(Database error)") return get_result(code=settings.RetCode.SUCCESS) diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index a9b9b76ae..70e1282bd 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -19,6 +19,7 @@ import logging import random import time from base64 import b64encode +from copy import deepcopy from functools import wraps from hmac import HMAC from io import BytesIO @@ -333,22 +334,6 @@ def generate_confirmation_token(tenant_id): return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34] -def valid(permission, valid_permission, chunk_method, valid_chunk_method): - if valid_parameter(permission, valid_permission): - return valid_parameter(permission, valid_permission) - if valid_parameter(chunk_method, valid_chunk_method): - return valid_parameter(chunk_method, valid_chunk_method) - - -def valid_parameter(parameter, valid_values): - if parameter and parameter not in valid_values: - return get_error_data_result(f"'{parameter}' is not in {valid_values}") - - -def dataset_readonly_fields(field_name): - return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time", "created_by", "document_count", "token_num", "status", "tenant_id", "id"] - - def get_parser_config(chunk_method, parser_config): if parser_config: return parser_config @@ -402,43 +387,6 @@ def get_data_openai( } -def valid_parser_config(parser_config): - if not parser_config: - return - scopes = set( - [ - "chunk_token_num", - "delimiter", - "raptor", - "graphrag", - "layout_recognize", - "task_page_size", - "pages", - "html4excel", - "auto_keywords", - "auto_questions", - "tag_kb_ids", - "topn_tags", - "filename_embd_weight", - ] - ) - for k in parser_config.keys(): - assert k in scopes, f"Abnormal 'parser_config'. Invalid key: {k}" - - assert isinstance(parser_config.get("chunk_token_num", 1), int), "chunk_token_num should be int" - assert 1 <= parser_config.get("chunk_token_num", 1) < 100000000, "chunk_token_num should be in range from 1 to 100000000" - assert isinstance(parser_config.get("task_page_size", 1), int), "task_page_size should be int" - assert 1 <= parser_config.get("task_page_size", 1) < 100000000, "task_page_size should be in range from 1 to 100000000" - assert isinstance(parser_config.get("auto_keywords", 1), int), "auto_keywords should be int" - assert 0 <= parser_config.get("auto_keywords", 0) < 32, "auto_keywords should be in range from 0 to 32" - assert isinstance(parser_config.get("auto_questions", 1), int), "auto_questions should be int" - assert 0 <= parser_config.get("auto_questions", 0) < 10, "auto_questions should be in range from 0 to 10" - assert isinstance(parser_config.get("topn_tags", 1), int), "topn_tags should be int" - assert 0 <= parser_config.get("topn_tags", 0) < 10, "topn_tags should be in range from 0 to 10" - assert isinstance(parser_config.get("html4excel", False), bool), "html4excel should be True or False" - assert isinstance(parser_config.get("delimiter", ""), str), "delimiter should be str" - - def check_duplicate_ids(ids, id_type="item"): """ Check for duplicate IDs in a list and return unique IDs and error messages. @@ -469,7 +417,8 @@ def check_duplicate_ids(ids, id_type="item"): def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]: - """Verifies availability of an embedding model for a specific tenant. + """ + Verifies availability of an embedding model for a specific tenant. Implements a four-stage validation process: 1. Model identifier parsing and validation @@ -518,3 +467,50 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, R return False, get_error_data_result(message="Database operation failed") return True, None + + +def deep_merge(default: dict, custom: dict) -> dict: + """ + Recursively merges two dictionaries with priority given to `custom` values. + + Creates a deep copy of the `default` dictionary and iteratively merges nested + dictionaries using a stack-based approach. Non-dict values in `custom` will + completely override corresponding entries in `default`. + + Args: + default (dict): Base dictionary containing default values. + custom (dict): Dictionary containing overriding values. + + Returns: + dict: New merged dictionary combining values from both inputs. + + Example: + >>> from copy import deepcopy + >>> default = {"a": 1, "nested": {"x": 10, "y": 20}} + >>> custom = {"b": 2, "nested": {"y": 99, "z": 30}} + >>> deep_merge(default, custom) + {'a': 1, 'b': 2, 'nested': {'x': 10, 'y': 99, 'z': 30}} + + >>> deep_merge({"config": {"mode": "auto"}}, {"config": "manual"}) + {'config': 'manual'} + + Notes: + 1. Merge priority is always given to `custom` values at all nesting levels + 2. Non-dict values (e.g. list, str) in `custom` will replace entire values + in `default`, even if the original value was a dictionary + 3. Time complexity: O(N) where N is total key-value pairs in `custom` + 4. Recommended for configuration merging and nested data updates + """ + merged = deepcopy(default) + stack = [(merged, custom)] + + while stack: + base_dict, override_dict = stack.pop() + + for key, val in override_dict.items(): + if key in base_dict and isinstance(val, dict) and isinstance(base_dict[key], dict): + stack.append((base_dict[key], val)) + else: + base_dict[key] = val + + return merged diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index d16e80149..c752fbcd9 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -13,21 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import uuid from enum import auto from typing import Annotated, Any from flask import Request -from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator +from pydantic import UUID1, BaseModel, Field, StringConstraints, ValidationError, field_serializer, field_validator from strenum import StrEnum from werkzeug.exceptions import BadRequest, UnsupportedMediaType from api.constants import DATASET_NAME_LIMIT -def validate_and_parse_json_request(request: Request, validator: type[BaseModel]) -> tuple[dict[str, Any] | None, str | None]: - """Validates and parses JSON requests through a multi-stage validation pipeline. +def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]: + """ + Validates and parses JSON requests through a multi-stage validation pipeline. - Implements a robust four-stage validation process: + Implements a four-stage validation process: 1. Content-Type verification (must be application/json) 2. JSON syntax validation 3. Payload structure type checking @@ -35,6 +37,10 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel] Args: request (Request): Flask request object containing HTTP payload + validator (type[BaseModel]): Pydantic model class for data validation + extras (dict[str, Any] | None): Additional fields to merge into payload + before validation. These fields will be removed from the final output + exclude_unset (bool): Whether to exclude fields that have not been explicitly set Returns: tuple[Dict[str, Any] | None, str | None]: @@ -46,26 +52,26 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel] - Diagnostic error message on failure Raises: - UnsupportedMediaType: When Content-Type ≠ application/json + UnsupportedMediaType: When Content-Type header is not application/json BadRequest: For structural JSON syntax errors ValidationError: When payload violates Pydantic schema rules Examples: - Successful validation: - ```python - # Input: {"name": "Dataset1", "format": "csv"} - # Returns: ({"name": "Dataset1", "format": "csv"}, None) - ``` + >>> validate_and_parse_json_request(valid_request, DatasetSchema) + ({"name": "Dataset1", "format": "csv"}, None) - Invalid Content-Type: - ```python - # Returns: (None, "Unsupported content type: Expected application/json, got text/xml") - ``` + >>> validate_and_parse_json_request(xml_request, DatasetSchema) + (None, "Unsupported content type: Expected application/json, got text/xml") - Malformed JSON: - ```python - # Returns: (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding") - ``` + >>> validate_and_parse_json_request(bad_json_request, DatasetSchema) + (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding") + + Notes: + 1. Validation Priority: + - Content-Type verification precedes JSON parsing + - Structural validation occurs before schema validation + 2. Extra fields added via `extras` parameter are automatically removed + from the final output after validation """ try: payload = request.get_json() or {} @@ -78,17 +84,25 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel] return None, f"Invalid request payload: expected object, got {type(payload).__name__}" try: + if extras is not None: + payload.update(extras) validated_request = validator(**payload) except ValidationError as e: return None, format_validation_error_message(e) - parsed_payload = validated_request.model_dump(by_alias=True) + parsed_payload = validated_request.model_dump(by_alias=True, exclude_unset=exclude_unset) + + if extras is not None: + for key in list(parsed_payload.keys()): + if key in extras: + del parsed_payload[key] return parsed_payload, None def format_validation_error_message(e: ValidationError) -> str: - """Formats validation errors into a standardized string format. + """ + Formats validation errors into a standardized string format. Processes pydantic ValidationError objects to create human-readable error messages containing field locations, error descriptions, and input values. @@ -155,7 +169,6 @@ class GraphragMethodEnum(StrEnum): class Base(BaseModel): class Config: extra = "forbid" - json_schema_extra = {"charset": "utf8mb4", "collation": "utf8mb4_0900_ai_ci"} class RaptorConfig(Base): @@ -201,16 +214,17 @@ class CreateDatasetReq(Base): name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)] avatar: str | None = Field(default=None, max_length=65535) description: str | None = Field(default=None, max_length=65535) - embedding_model: Annotated[str | None, StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")] + embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")] permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)] chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")] pagerank: int = Field(default=0, ge=0, le=100) - parser_config: ParserConfig | None = Field(default=None) + parser_config: ParserConfig = Field(default_factory=dict) @field_validator("avatar") @classmethod - def validate_avatar_base64(cls, v: str) -> str: - """Validates Base64-encoded avatar string format and MIME type compliance. + def validate_avatar_base64(cls, v: str | None) -> str | None: + """ + Validates Base64-encoded avatar string format and MIME type compliance. Implements a three-stage validation workflow: 1. MIME prefix existence check @@ -259,7 +273,8 @@ class CreateDatasetReq(Base): @field_validator("embedding_model", mode="after") @classmethod def validate_embedding_model(cls, v: str) -> str: - """Validates embedding model identifier format compliance. + """ + Validates embedding model identifier format compliance. Validation pipeline: 1. Structural format verification @@ -298,11 +313,12 @@ class CreateDatasetReq(Base): @field_validator("permission", mode="before") @classmethod - def permission_auto_lowercase(cls, v: str) -> str: - """Normalize permission input to lowercase for consistent PermissionEnum matching. + def permission_auto_lowercase(cls, v: Any) -> Any: + """ + Normalize permission input to lowercase for consistent PermissionEnum matching. Args: - v (str): Raw input value for the permission field + v (Any): Raw input value for the permission field Returns: Lowercase string if input is string type, otherwise returns original value @@ -316,13 +332,13 @@ class CreateDatasetReq(Base): @field_validator("parser_config", mode="after") @classmethod - def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserConfig | None: - """Validates serialized JSON length constraints for parser configuration. + def validate_parser_config_json_length(cls, v: ParserConfig) -> ParserConfig: + """ + Validates serialized JSON length constraints for parser configuration. - Implements a three-stage validation workflow: - 1. Null check - bypass validation for empty configurations - 2. Model serialization - convert Pydantic model to JSON string - 3. Size verification - enforce maximum allowed payload size + Implements a two-stage validation workflow: + 1. Model serialization - convert Pydantic model to JSON string + 2. Size verification - enforce maximum allowed payload size Args: v (ParserConfig | None): Raw parser configuration object @@ -333,9 +349,15 @@ class CreateDatasetReq(Base): Raises: ValueError: When serialized JSON exceeds 65,535 characters """ - if v is None: - return v - if (json_str := v.model_dump_json()) and len(json_str) > 65535: raise ValueError(f"Parser config exceeds size limit (max 65,535 characters). Current size: {len(json_str):,}") return v + + +class UpdateDatasetReq(CreateDatasetReq): + dataset_id: UUID1 = Field(...) + name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")] + + @field_serializer("dataset_id") + def serialize_uuid_to_hex(self, v: uuid.UUID) -> str: + return v.hex diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index 0843ab124..cc3634519 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -385,7 +385,7 @@ curl --request POST \ - `"team"`: All team members can manage the dataset. - `"pagerank"`: (*Body parameter*), `int` - Set page rank: refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank) + refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank) - Default: `0` - Minimum: `0` - Maximum: `100` @@ -562,8 +562,13 @@ Updates configurations for a specified dataset. - `'Authorization: Bearer '` - Body: - `"name"`: `string` + - `"avatar"`: `string` + - `"description"`: `string` - `"embedding_model"`: `string` - - `"chunk_method"`: `enum` + - `"permission"`: `string` + - `"chunk_method"`: `string` + - `"pagerank"`: `int` + - `"parser_config"`: `object` ##### Request example @@ -584,22 +589,74 @@ curl --request PUT \ The ID of the dataset to update. - `"name"`: (*Body parameter*), `string` The revised name of the dataset. + - Basic Multilingual Plane (BMP) only + - Maximum 128 characters + - Case-insensitive +- `"avatar"`: (*Body parameter*), `string` + The updated base64 encoding of the avatar. + - Maximum 65535 characters - `"embedding_model"`: (*Body parameter*), `string` The updated embedding model name. - Ensure that `"chunk_count"` is `0` before updating `"embedding_model"`. + - Maximum 255 characters + - Must follow `model_name@model_factory` format +- `"permission"`: (*Body parameter*), `string` + The updated dataset permission. Available options: + - `"me"`: (Default) Only you can manage the dataset. + - `"team"`: All team members can manage the dataset. +- `"pagerank"`: (*Body parameter*), `int` + refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank) + - Default: `0` + - Minimum: `0` + - Maximum: `100` - `"chunk_method"`: (*Body parameter*), `enum` The chunking method for the dataset. Available options: - - `"naive"`: General - - `"manual`: Manual + - `"naive"`: General (default) + - `"book"`: Book + - `"email"`: Email + - `"laws"`: Laws + - `"manual"`: Manual + - `"one"`: One + - `"paper"`: Paper + - `"picture"`: Picture + - `"presentation"`: Presentation - `"qa"`: Q&A - `"table"`: Table - - `"paper"`: Paper - - `"book"`: Book - - `"laws"`: Laws - - `"presentation"`: Presentation - - `"picture"`: Picture - - `"one"`:One - - `"email"`: Email + - `"tag"`: Tag +- `"parser_config"`: (*Body parameter*), `object` + The configuration settings for the dataset parser. The attributes in this JSON object vary with the selected `"chunk_method"`: + - If `"chunk_method"` is `"naive"`, the `"parser_config"` object contains the following attributes: + - `"auto_keywords"`: `int` + - Defaults to `0` + - Minimum: `0` + - Maximum: `32` + - `"auto_questions"`: `int` + - Defaults to `0` + - Minimum: `0` + - Maximum: `10` + - `"chunk_token_num"`: `int` + - Defaults to `128` + - Minimum: `1` + - Maximum: `2048` + - `"delimiter"`: `string` + - Defaults to `"\n"`. + - `"html4excel"`: `bool` Indicates whether to convert Excel documents into HTML format. + - Defaults to `false` + - `"layout_recognize"`: `string` + - Defaults to `DeepDOC` + - `"tag_kb_ids"`: `array` refer to [Use tag set](https://ragflow.io/docs/dev/use_tag_sets) + - Must include a list of dataset IDs, where each dataset is parsed using the ​​Tag Chunk Method + - `"task_page_size"`: `int` For PDF only. + - Defaults to `12` + - Minimum: `1` + - `"raptor"`: `object` RAPTOR-specific settings. + - Defaults to: `{"use_raptor": false}` + - `"graphrag"`: `object` GRAPHRAG-specific settings. + - Defaults to: `{"use_graphrag": false}` + - If `"chunk_method"` is `"qa"`, `"manuel"`, `"paper"`, `"book"`, `"laws"`, or `"presentation"`, the `"parser_config"` object contains the following attribute: + - `"raptor"`: `object` RAPTOR-specific settings. + - Defaults to: `{"use_raptor": false}`. + - If `"chunk_method"` is `"table"`, `"picture"`, `"one"`, or `"email"`, `"parser_config"` is an empty JSON object. #### Response diff --git a/docs/references/python_api_reference.md b/docs/references/python_api_reference.md index 5925ca6e1..be827f720 100644 --- a/docs/references/python_api_reference.md +++ b/docs/references/python_api_reference.md @@ -306,20 +306,40 @@ Updates configurations for the current dataset. A dictionary representing the attributes to update, with the following keys: - `"name"`: `str` The revised name of the dataset. -- `"embedding_model"`: `str` The updated embedding model name. + - Basic Multilingual Plane (BMP) only + - Maximum 128 characters + - Case-insensitive +- `"avatar"`: (*Body parameter*), `string` + The updated base64 encoding of the avatar. + - Maximum 65535 characters +- `"embedding_model"`: (*Body parameter*), `string` + The updated embedding model name. - Ensure that `"chunk_count"` is `0` before updating `"embedding_model"`. -- `"chunk_method"`: `str` The chunking method for the dataset. Available options: - - `"naive"`: General - - `"manual`: Manual + - Maximum 255 characters + - Must follow `model_name@model_factory` format +- `"permission"`: (*Body parameter*), `string` + The updated dataset permission. Available options: + - `"me"`: (Default) Only you can manage the dataset. + - `"team"`: All team members can manage the dataset. +- `"pagerank"`: (*Body parameter*), `int` + refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank) + - Default: `0` + - Minimum: `0` + - Maximum: `100` +- `"chunk_method"`: (*Body parameter*), `enum` + The chunking method for the dataset. Available options: + - `"naive"`: General (default) + - `"book"`: Book + - `"email"`: Email + - `"laws"`: Laws + - `"manual"`: Manual + - `"one"`: One + - `"paper"`: Paper + - `"picture"`: Picture + - `"presentation"`: Presentation - `"qa"`: Q&A - `"table"`: Table - - `"paper"`: Paper - - `"book"`: Book - - `"laws"`: Laws - - `"presentation"`: Presentation - - `"picture"`: Picture - - `"one"`: One - - `"email"`: Email + - `"tag"`: Tag #### Returns diff --git a/sdk/python/ragflow_sdk/ragflow.py b/sdk/python/ragflow_sdk/ragflow.py index 1c283651c..9b6267cdd 100644 --- a/sdk/python/ragflow_sdk/ragflow.py +++ b/sdk/python/ragflow_sdk/ragflow.py @@ -59,21 +59,19 @@ class RAGFlow: pagerank: int = 0, parser_config: DataSet.ParserConfig = None, ) -> DataSet: - if parser_config: - parser_config = parser_config.to_json() - res = self.post( - "/datasets", - { - "name": name, - "avatar": avatar, - "description": description, - "embedding_model": embedding_model, - "permission": permission, - "chunk_method": chunk_method, - "pagerank": pagerank, - "parser_config": parser_config, - }, - ) + payload = { + "name": name, + "avatar": avatar, + "description": description, + "embedding_model": embedding_model, + "permission": permission, + "chunk_method": chunk_method, + "pagerank": pagerank, + } + if parser_config is not None: + payload["parser_config"] = parser_config.to_json() + + res = self.post("/datasets", payload) res = res.json() if res.get("code") == 0: return DataSet(self, res["data"]) diff --git a/sdk/python/test/libs/utils/hypothesis_utils.py b/sdk/python/test/libs/utils/hypothesis_utils.py new file mode 100644 index 000000000..736e6cbdf --- /dev/null +++ b/sdk/python/test/libs/utils/hypothesis_utils.py @@ -0,0 +1,28 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import hypothesis.strategies as st + + +@st.composite +def valid_names(draw): + base_chars = "abcdefghijklmnopqrstuvwxyz_" + first_char = draw(st.sampled_from([c for c in base_chars if c.isalpha() or c == "_"])) + remaining = draw(st.text(alphabet=st.sampled_from(base_chars), min_size=0, max_size=128 - 2)) + + name = (first_char + remaining)[:128] + return name.encode("utf-8").decode("utf-8") diff --git a/sdk/python/test/test_http_api/common.py b/sdk/python/test/test_http_api/common.py index 770f716e4..f3010b589 100644 --- a/sdk/python/test/test_http_api/common.py +++ b/sdk/python/test/test_http_api/common.py @@ -39,23 +39,23 @@ SESSION_WITH_CHAT_NAME_LIMIT = 255 # DATASET MANAGEMENT -def create_dataset(auth, payload=None, headers=HEADERS, data=None): +def create_dataset(auth, payload=None, *, headers=HEADERS, data=None): res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data) return res.json() -def list_datasets(auth, params=None, headers=HEADERS): +def list_datasets(auth, params=None, *, headers=HEADERS): res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params) return res.json() -def update_dataset(auth, dataset_id, payload=None, headers=HEADERS): - res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload) +def update_dataset(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): + res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload, data=data) return res.json() -def delete_datasets(auth, payload=None, headers=HEADERS): - res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload) +def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data) return res.json() diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/conftest.py b/sdk/python/test/test_http_api/test_dataset_mangement/conftest.py index 6f7582912..8694ccead 100644 --- a/sdk/python/test/test_http_api/test_dataset_mangement/conftest.py +++ b/sdk/python/test/test_http_api/test_dataset_mangement/conftest.py @@ -37,3 +37,13 @@ def add_datasets_func(get_http_api_auth, request): request.addfinalizer(cleanup) return batch_create_datasets(get_http_api_auth, 3) + + +@pytest.fixture(scope="function") +def add_dataset_func(get_http_api_auth, request): + def cleanup(): + delete_datasets(get_http_api_auth) + + request.addfinalizer(cleanup) + + return batch_create_datasets(get_http_api_auth, 1)[0] diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py b/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py index 2cbd26757..3a70d248d 100644 --- a/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py +++ b/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py @@ -13,30 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from concurrent.futures import ThreadPoolExecutor - -import hypothesis.strategies as st import pytest from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, create_dataset from hypothesis import example, given, settings from libs.auth import RAGFlowHttpApiAuth from libs.utils import encode_avatar from libs.utils.file_utils import create_image_file +from libs.utils.hypothesis_utils import valid_names -@st.composite -def valid_names(draw): - base_chars = "abcdefghijklmnopqrstuvwxyz_" - first_char = draw(st.sampled_from([c for c in base_chars if c.isalpha() or c == "_"])) - remaining = draw(st.text(alphabet=st.sampled_from(base_chars), min_size=0, max_size=DATASET_NAME_LIMIT - 2)) - - name = (first_char + remaining)[:128] - return name.encode("utf-8").decode("utf-8") - - -@pytest.mark.p1 @pytest.mark.usefixtures("clear_datasets") class TestAuthorization: + @pytest.mark.p1 @pytest.mark.parametrize( "auth, expected_code, expected_message", [ @@ -49,64 +39,17 @@ class TestAuthorization: ], ids=["empty_auth", "invalid_api_token"], ) - def test_invalid_auth(self, auth, expected_code, expected_message): + def test_auth_invalid(self, auth, expected_code, expected_message): res = create_dataset(auth, {"name": "auth_test"}) - assert res["code"] == expected_code - assert res["message"] == expected_message + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res -@pytest.mark.usefixtures("clear_datasets") -class TestDatasetCreation: - @pytest.mark.p1 - @given(name=valid_names()) - @example("a" * 128) - @settings(max_examples=20) - def test_valid_name(self, get_http_api_auth, name): - res = create_dataset(get_http_api_auth, {"name": name}) - assert res["code"] == 0, res - assert res["data"]["name"] == name, res - - @pytest.mark.p1 - @pytest.mark.parametrize( - "name, expected_message", - [ - ("", "String should have at least 1 character"), - (" ", "String should have at least 1 character"), - ("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"), - (0, "Input should be a valid string"), - ], - ids=["empty_name", "space_name", "too_long_name", "invalid_name"], - ) - def test_invalid_name(self, get_http_api_auth, name, expected_message): - res = create_dataset(get_http_api_auth, {"name": name}) - assert res["code"] == 101, res - assert expected_message in res["message"], res - - @pytest.mark.p2 - def test_duplicated_name(self, get_http_api_auth): - name = "duplicated_name" - payload = {"name": name} - res = create_dataset(get_http_api_auth, payload) - assert res["code"] == 0, res - - res = create_dataset(get_http_api_auth, payload) - assert res["code"] == 101, res - assert res["message"] == f"Dataset name '{name}' already exists", res - - @pytest.mark.p2 - def test_case_insensitive(self, get_http_api_auth): - name = "CaseInsensitive" - res = create_dataset(get_http_api_auth, {"name": name.upper()}) - assert res["code"] == 0, res - - res = create_dataset(get_http_api_auth, {"name": name.lower()}) - assert res["code"] == 101, res - assert res["message"] == f"Dataset name '{name.lower()}' already exists", res - +class TestRquest: @pytest.mark.p3 - def test_bad_content_type(self, get_http_api_auth): + def test_content_type_bad(self, get_http_api_auth): BAD_CONTENT_TYPE = "text/xml" - res = create_dataset(get_http_api_auth, {"name": "name"}, {"Content-Type": BAD_CONTENT_TYPE}) + res = create_dataset(get_http_api_auth, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE}) assert res["code"] == 101, res assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res @@ -115,15 +58,85 @@ class TestDatasetCreation: "payload, expected_message", [ ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"), - ('"a"', "Invalid request payload: expected objec"), + ('"a"', "Invalid request payload: expected object, got str"), ], ids=["malformed_json_syntax", "invalid_request_payload_type"], ) - def test_bad_payload(self, get_http_api_auth, payload, expected_message): + def test_payload_bad(self, get_http_api_auth, payload, expected_message): res = create_dataset(get_http_api_auth, data=payload) assert res["code"] == 101, res + assert res["message"] == expected_message, res + + +@pytest.mark.usefixtures("clear_datasets") +class TestCapability: + @pytest.mark.p3 + def test_create_dataset_1k(self, get_http_api_auth): + for i in range(1_000): + payload = {"name": f"dataset_{i}"} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, f"Failed to create dataset {i}" + + @pytest.mark.p3 + def test_create_dataset_concurrent(self, get_http_api_auth): + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(create_dataset, get_http_api_auth, {"name": f"dataset_{i}"}) for i in range(100)] + responses = [f.result() for f in futures] + assert all(r["code"] == 0 for r in responses), responses + + +@pytest.mark.usefixtures("clear_datasets") +class TestDatasetCreate: + @pytest.mark.p1 + @given(name=valid_names()) + @example("a" * 128) + @settings(max_examples=20) + def test_name(self, get_http_api_auth, name): + res = create_dataset(get_http_api_auth, {"name": name}) + assert res["code"] == 0, res + assert res["data"]["name"] == name, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, expected_message", + [ + ("", "String should have at least 1 character"), + (" ", "String should have at least 1 character"), + ("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"), + (0, "Input should be a valid string"), + (None, "Input should be a valid string"), + ], + ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], + ) + def test_name_invalid(self, get_http_api_auth, name, expected_message): + payload = {"name": name} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 101, res assert expected_message in res["message"], res + @pytest.mark.p3 + def test_name_duplicated(self, get_http_api_auth): + name = "duplicated_name" + payload = {"name": name} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, res + + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 102, res + assert res["message"] == f"Dataset name '{name}' already exists", res + + @pytest.mark.p3 + def test_name_case_insensitive(self, get_http_api_auth): + name = "CaseInsensitive" + payload = {"name": name.upper()} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, res + + payload = {"name": name.lower()} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 102, res + assert res["message"] == f"Dataset name '{name.lower()}' already exists", res + @pytest.mark.p2 def test_avatar(self, get_http_api_auth, tmp_path): fn = create_image_file(tmp_path / "ragflow_test.png") @@ -134,16 +147,10 @@ class TestDatasetCreation: res = create_dataset(get_http_api_auth, payload) assert res["code"] == 0, res - @pytest.mark.p3 - def test_avatar_none(self, get_http_api_auth, tmp_path): - payload = {"name": "test_avatar_none", "avatar": None} - res = create_dataset(get_http_api_auth, payload) - assert res["code"] == 0, res - assert res["data"]["avatar"] is None, res - @pytest.mark.p2 def test_avatar_exceeds_limit_length(self, get_http_api_auth): - res = create_dataset(get_http_api_auth, {"name": "exceeds_limit_length_avatar", "avatar": "a" * 65536}) + payload = {"name": "exceeds_limit_length_avatar", "avatar": "a" * 65536} + res = create_dataset(get_http_api_auth, payload) assert res["code"] == 101, res assert "String should have at most 65535 characters" in res["message"], res @@ -158,7 +165,7 @@ class TestDatasetCreation: ], ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"], ) - def test_invalid_avatar_prefix(self, get_http_api_auth, tmp_path, name, avatar_prefix, expected_message): + def test_avatar_invalid_prefix(self, get_http_api_auth, tmp_path, name, avatar_prefix, expected_message): fn = create_image_file(tmp_path / "ragflow_test.png") payload = { "name": name, @@ -169,11 +176,25 @@ class TestDatasetCreation: assert expected_message in res["message"], res @pytest.mark.p3 - def test_description_none(self, get_http_api_auth): - payload = {"name": "test_description_none", "description": None} + def test_avatar_unset(self, get_http_api_auth): + payload = {"name": "test_avatar_unset"} res = create_dataset(get_http_api_auth, payload) assert res["code"] == 0, res - assert res["data"]["description"] is None, res + assert res["data"]["avatar"] is None, res + + @pytest.mark.p3 + def test_avatar_none(self, get_http_api_auth): + payload = {"name": "test_avatar_none", "avatar": None} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, res + assert res["data"]["avatar"] is None, res + + @pytest.mark.p2 + def test_description(self, get_http_api_auth): + payload = {"name": "test_description", "description": "description"} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, res + assert res["data"]["description"] == "description", res @pytest.mark.p2 def test_description_exceeds_limit_length(self, get_http_api_auth): @@ -182,6 +203,20 @@ class TestDatasetCreation: assert res["code"] == 101, res assert "String should have at most 65535 characters" in res["message"], res + @pytest.mark.p3 + def test_description_unset(self, get_http_api_auth): + payload = {"name": "test_description_unset"} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, res + assert res["data"]["description"] is None, res + + @pytest.mark.p3 + def test_description_none(self, get_http_api_auth): + payload = {"name": "test_description_none", "description": None} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, res + assert res["data"]["description"] is None, res + @pytest.mark.p1 @pytest.mark.parametrize( "name, embedding_model", @@ -189,22 +224,14 @@ class TestDatasetCreation: ("BAAI/bge-large-zh-v1.5@BAAI", "BAAI/bge-large-zh-v1.5@BAAI"), ("maidalun1020/bce-embedding-base_v1@Youdao", "maidalun1020/bce-embedding-base_v1@Youdao"), ("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"), - ("embedding_model_default", None), ], - ids=["builtin_baai", "builtin_youdao", "tenant_zhipu", "default"], + ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"], ) - def test_valid_embedding_model(self, get_http_api_auth, name, embedding_model): - if embedding_model is None: - payload = {"name": name} - else: - payload = {"name": name, "embedding_model": embedding_model} - + def test_embedding_model(self, get_http_api_auth, name, embedding_model): + payload = {"name": name, "embedding_model": embedding_model} res = create_dataset(get_http_api_auth, payload) assert res["code"] == 0, res - if embedding_model is None: - assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res - else: - assert res["data"]["embedding_model"] == embedding_model, res + assert res["data"]["embedding_model"] == embedding_model, res @pytest.mark.p2 @pytest.mark.parametrize( @@ -217,7 +244,7 @@ class TestDatasetCreation: ], ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"], ) - def test_invalid_embedding_model(self, get_http_api_auth, name, embedding_model): + def test_embedding_model_invalid(self, get_http_api_auth, name, embedding_model): payload = {"name": name, "embedding_model": embedding_model} res = create_dataset(get_http_api_auth, payload) assert res["code"] == 101, res @@ -247,6 +274,20 @@ class TestDatasetCreation: else: assert "Both model_name and provider must be non-empty strings" in res["message"], res + @pytest.mark.p2 + def test_embedding_model_unset(self, get_http_api_auth): + payload = {"name": "embedding_model_unset"} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, res + assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res + + @pytest.mark.p2 + def test_embedding_model_none(self, get_http_api_auth): + payload = {"name": "test_embedding_model_none", "embedding_model": None} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 101, res + assert "Input should be a valid string" in res["message"], res + @pytest.mark.p1 @pytest.mark.parametrize( "name, permission", @@ -255,21 +296,14 @@ class TestDatasetCreation: ("team", "team"), ("me_upercase", "ME"), ("team_upercase", "TEAM"), - ("permission_default", None), ], - ids=["me", "team", "me_upercase", "team_upercase", "permission_default"], + ids=["me", "team", "me_upercase", "team_upercase"], ) - def test_valid_permission(self, get_http_api_auth, name, permission): - if permission is None: - payload = {"name": name} - else: - payload = {"name": name, "permission": permission} + def test_permission(self, get_http_api_auth, name, permission): + payload = {"name": name, "permission": permission} res = create_dataset(get_http_api_auth, payload) assert res["code"] == 0, res - if permission is None: - assert res["data"]["permission"] == "me", res - else: - assert res["data"]["permission"] == permission.lower(), res + assert res["data"]["permission"] == permission.lower(), res @pytest.mark.p2 @pytest.mark.parametrize( @@ -279,13 +313,28 @@ class TestDatasetCreation: ("unknown", "unknown"), ("type_error", list()), ], + ids=["empty", "unknown", "type_error"], ) - def test_invalid_permission(self, get_http_api_auth, name, permission): + def test_permission_invalid(self, get_http_api_auth, name, permission): payload = {"name": name, "permission": permission} res = create_dataset(get_http_api_auth, payload) assert res["code"] == 101 assert "Input should be 'me' or 'team'" in res["message"] + @pytest.mark.p2 + def test_permission_unset(self, get_http_api_auth): + payload = {"name": "test_permission_unset"} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, res + assert res["data"]["permission"] == "me", res + + @pytest.mark.p3 + def test_permission_none(self, get_http_api_auth): + payload = {"name": "test_permission_none", "permission": None} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 101, res + assert "Input should be 'me' or 'team'" in res["message"], res + @pytest.mark.p1 @pytest.mark.parametrize( "name, chunk_method", @@ -302,20 +351,14 @@ class TestDatasetCreation: ("qa", "qa"), ("table", "table"), ("tag", "tag"), - ("chunk_method_default", None), ], + ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"], ) - def test_valid_chunk_method(self, get_http_api_auth, name, chunk_method): - if chunk_method is None: - payload = {"name": name} - else: - payload = {"name": name, "chunk_method": chunk_method} + def test_chunk_method(self, get_http_api_auth, name, chunk_method): + payload = {"name": name, "chunk_method": chunk_method} res = create_dataset(get_http_api_auth, payload) assert res["code"] == 0, res - if chunk_method is None: - assert res["data"]["chunk_method"] == "naive", res - else: - assert res["data"]["chunk_method"] == chunk_method, res + assert res["data"]["chunk_method"] == chunk_method, res @pytest.mark.p2 @pytest.mark.parametrize( @@ -325,19 +368,77 @@ class TestDatasetCreation: ("unknown", "unknown"), ("type_error", list()), ], + ids=["empty", "unknown", "type_error"], ) - def test_invalid_chunk_method(self, get_http_api_auth, name, chunk_method): + def test_chunk_method_invalid(self, get_http_api_auth, name, chunk_method): payload = {"name": name, "chunk_method": chunk_method} res = create_dataset(get_http_api_auth, payload) assert res["code"] == 101, res assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res + @pytest.mark.p2 + def test_chunk_method_unset(self, get_http_api_auth): + payload = {"name": "test_chunk_method_unset"} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, res + assert res["data"]["chunk_method"] == "naive", res + + @pytest.mark.p3 + def test_chunk_method_none(self, get_http_api_auth): + payload = {"name": "chunk_method_none", "chunk_method": None} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 101, res + assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, pagerank", + [ + ("pagerank_min", 0), + ("pagerank_mid", 50), + ("pagerank_max", 100), + ], + ids=["min", "mid", "max"], + ) + def test_pagerank(self, get_http_api_auth, name, pagerank): + payload = {"name": name, "pagerank": pagerank} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, res + assert res["data"]["pagerank"] == pagerank, res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "name, pagerank, expected_message", + [ + ("pagerank_min_limit", -1, "Input should be greater than or equal to 0"), + ("pagerank_max_limit", 101, "Input should be less than or equal to 100"), + ], + ids=["min_limit", "max_limit"], + ) + def test_pagerank_invalid(self, get_http_api_auth, name, pagerank, expected_message): + payload = {"name": name, "pagerank": pagerank} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 101, res + assert expected_message in res["message"], res + + @pytest.mark.p3 + def test_pagerank_unset(self, get_http_api_auth): + payload = {"name": "pagerank_unset"} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, res + assert res["data"]["pagerank"] == 0, res + + @pytest.mark.p3 + def test_pagerank_none(self, get_http_api_auth): + payload = {"name": "pagerank_unset", "pagerank": None} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 101, res + assert "Input should be a valid integer" in res["message"], res + @pytest.mark.p1 @pytest.mark.parametrize( "name, parser_config", [ - ("default_none", None), - ("default_empty", {}), ("auto_keywords_min", {"auto_keywords": 0}), ("auto_keywords_mid", {"auto_keywords": 16}), ("auto_keywords_max", {"auto_keywords": 32}), @@ -363,7 +464,7 @@ class TestDatasetCreation: ("task_page_size_min", {"task_page_size": 1}), ("task_page_size_None", {"task_page_size": None}), ("pages", {"pages": [[1, 100]]}), - ("pages_none", None), + ("pages_none", {"pages": None}), ("graphrag_true", {"graphrag": {"use_graphrag": True}}), ("graphrag_false", {"graphrag": {"use_graphrag": False}}), ("graphrag_entity_types", {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}), @@ -388,8 +489,6 @@ class TestDatasetCreation: ("raptor_random_seed_min", {"raptor": {"random_seed": 0}}), ], ids=[ - "default_none", - "default_empty", "auto_keywords_min", "auto_keywords_mid", "auto_keywords_max", @@ -440,44 +539,16 @@ class TestDatasetCreation: "raptor_random_seed_min", ], ) - def test_valid_parser_config(self, get_http_api_auth, name, parser_config): - if parser_config is None: - payload = {"name": name} - else: - payload = {"name": name, "parser_config": parser_config} + def test_parser_config(self, get_http_api_auth, name, parser_config): + payload = {"name": name, "parser_config": parser_config} res = create_dataset(get_http_api_auth, payload) assert res["code"] == 0, res - if parser_config is None: - assert res["data"]["parser_config"] == { - "chunk_token_num": 128, - "delimiter": r"\n", - "html4excel": False, - "layout_recognize": "DeepDOC", - "raptor": {"use_raptor": False}, - } - elif parser_config == {}: - assert res["data"]["parser_config"] == { - "auto_keywords": 0, - "auto_questions": 0, - "chunk_token_num": 128, - "delimiter": r"\n", - "filename_embd_weight": None, - "graphrag": None, - "html4excel": False, - "layout_recognize": "DeepDOC", - "pages": None, - "raptor": None, - "tag_kb_ids": [], - "task_page_size": None, - "topn_tags": 1, - } - else: - for k, v in parser_config.items(): - if isinstance(v, dict): - for kk, vv in v.items(): - assert res["data"]["parser_config"][k][kk] == vv - else: - assert res["data"]["parser_config"][k] == v + for k, v in parser_config.items(): + if isinstance(v, dict): + for kk, vv in v.items(): + assert res["data"]["parser_config"][k][kk] == vv, res + else: + assert res["data"]["parser_config"][k] == v, res @pytest.mark.p2 @pytest.mark.parametrize( @@ -595,15 +666,72 @@ class TestDatasetCreation: "parser_config_type_invalid", ], ) - def test_invalid_parser_config(self, get_http_api_auth, name, parser_config, expected_message): + def test_parser_config_invalid(self, get_http_api_auth, name, parser_config, expected_message): payload = {"name": name, "parser_config": parser_config} res = create_dataset(get_http_api_auth, payload) assert res["code"] == 101, res assert expected_message in res["message"], res + @pytest.mark.p2 + def test_parser_config_empty(self, get_http_api_auth): + payload = {"name": "default_empty", "parser_config": {}} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, res + assert res["data"]["parser_config"] == { + "auto_keywords": 0, + "auto_questions": 0, + "chunk_token_num": 128, + "delimiter": r"\n", + "filename_embd_weight": None, + "graphrag": None, + "html4excel": False, + "layout_recognize": "DeepDOC", + "pages": None, + "raptor": None, + "tag_kb_ids": [], + "task_page_size": None, + "topn_tags": 1, + } + + @pytest.mark.p2 + def test_parser_config_unset(self, get_http_api_auth): + payload = {"name": "default_unset"} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 0, res + assert res["data"]["parser_config"] == { + "chunk_token_num": 128, + "delimiter": r"\n", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, + }, res + @pytest.mark.p3 - def test_dataset_10k(self, get_http_api_auth): - for i in range(10_000): - payload = {"name": f"dataset_{i}"} - res = create_dataset(get_http_api_auth, payload) - assert res["code"] == 0, f"Failed to create dataset {i}" + def test_parser_config_none(self, get_http_api_auth): + payload = {"name": "default_none", "parser_config": None} + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 101, res + assert "Input should be a valid dictionary or instance of ParserConfig" in res["message"], res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload", + [ + {"name": "id", "id": "id"}, + {"name": "tenant_id", "tenant_id": "e57c1966f99211efb41e9e45646e0111"}, + {"name": "created_by", "created_by": "created_by"}, + {"name": "create_date", "create_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, + {"name": "create_time", "create_time": 1741671443322}, + {"name": "update_date", "update_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, + {"name": "update_time", "update_time": 1741671443339}, + {"name": "document_count", "document_count": 1}, + {"name": "chunk_count", "chunk_count": 1}, + {"name": "token_num", "token_num": 1}, + {"name": "status", "status": "1"}, + {"name": "unknown_field", "unknown_field": "unknown_field"}, + ], + ) + def test_unsupported_field(self, get_http_api_auth, payload): + res = create_dataset(get_http_api_auth, payload) + assert res["code"] == 101, res + assert "Extra inputs are not permitted" in res["message"], res diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py b/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py index ea695de98..07485609d 100644 --- a/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py +++ b/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py @@ -16,21 +16,18 @@ from concurrent.futures import ThreadPoolExecutor import pytest -from common import ( - DATASET_NAME_LIMIT, - INVALID_API_TOKEN, - list_datasets, - update_dataset, -) +from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset +from hypothesis import HealthCheck, example, given, settings from libs.auth import RAGFlowHttpApiAuth from libs.utils import encode_avatar from libs.utils.file_utils import create_image_file +from libs.utils.hypothesis_utils import valid_names # TODO: Missing scenario for updating embedding_model with chunk_count != 0 -@pytest.mark.p1 class TestAuthorization: + @pytest.mark.p1 @pytest.mark.parametrize( "auth, expected_code, expected_message", [ @@ -41,111 +38,178 @@ class TestAuthorization: "Authentication error: API key is invalid!", ), ], + ids=["empty_auth", "invalid_api_token"], ) - def test_invalid_auth(self, auth, expected_code, expected_message): + def test_auth_invalid(self, auth, expected_code, expected_message): res = update_dataset(auth, "dataset_id") - assert res["code"] == expected_code - assert res["message"] == expected_message + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestRquest: + @pytest.mark.p3 + def test_bad_content_type(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func + BAD_CONTENT_TYPE = "text/xml" + res = update_dataset(get_http_api_auth, dataset_id, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE}) + assert res["code"] == 101, res + assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"), + ('"a"', "Invalid request payload: expected object, got str"), + ], + ids=["malformed_json_syntax", "invalid_request_payload_type"], + ) + def test_payload_bad(self, get_http_api_auth, add_dataset_func, payload, expected_message): + dataset_id = add_dataset_func + res = update_dataset(get_http_api_auth, dataset_id, data=payload) + assert res["code"] == 101, res + assert res["message"] == expected_message, res + + @pytest.mark.p2 + def test_payload_empty(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func + res = update_dataset(get_http_api_auth, dataset_id, {}) + assert res["code"] == 101, res + assert res["message"] == "No properties were modified", res + + +class TestCapability: + @pytest.mark.p3 + def test_update_dateset_concurrent(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)] + responses = [f.result() for f in futures] + assert all(r["code"] == 0 for r in responses), responses -@pytest.mark.p1 class TestDatasetUpdate: - @pytest.mark.parametrize( - "name, expected_code, expected_message", - [ - ("valid_name", 0, ""), - ( - "a" * (DATASET_NAME_LIMIT + 1), - 102, - "Dataset name should not be longer than 128 characters.", - ), - (0, 100, """AttributeError("\'int\' object has no attribute \'strip\'")"""), - ( - None, - 100, - """AttributeError("\'NoneType\' object has no attribute \'strip\'")""", - ), - pytest.param("", 102, "", marks=pytest.mark.skip(reason="issue/5915")), - ("dataset_1", 102, "Duplicated dataset name in updating dataset."), - ("DATASET_1", 102, "Duplicated dataset name in updating dataset."), - ], - ) - def test_name(self, get_http_api_auth, add_datasets_func, name, expected_code, expected_message): - dataset_ids = add_datasets_func - res = update_dataset(get_http_api_auth, dataset_ids[0], {"name": name}) - assert res["code"] == expected_code - if expected_code == 0: - res = list_datasets(get_http_api_auth, {"id": dataset_ids[0]}) - assert res["data"][0]["name"] == name - else: - assert res["message"] == expected_message + @pytest.mark.p3 + def test_dataset_id_not_uuid(self, get_http_api_auth): + payload = {"name": "dataset_id_not_uuid"} + res = update_dataset(get_http_api_auth, "not_uuid", payload) + assert res["code"] == 101, res + assert "Input should be a valid UUID" in res["message"], res - @pytest.mark.parametrize( - "embedding_model, expected_code, expected_message", - [ - ("BAAI/bge-large-zh-v1.5", 0, ""), - ("maidalun1020/bce-embedding-base_v1", 0, ""), - ( - "other_embedding_model", - 102, - "`embedding_model` other_embedding_model doesn't exist", - ), - (None, 102, "`embedding_model` can't be empty"), - ], - ) - def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model, expected_code, expected_message): + @pytest.mark.p3 + def test_dataset_id_wrong_uuid(self, get_http_api_auth): + payload = {"name": "wrong_uuid"} + res = update_dataset(get_http_api_auth, "d94a8dc02c9711f0930f7fbc369eab6d", payload) + assert res["code"] == 102, res + assert "lacks permission for dataset" in res["message"], res + + @pytest.mark.p1 + @given(name=valid_names()) + @example("a" * 128) + @settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_name(self, get_http_api_auth, add_dataset_func, name): dataset_id = add_dataset_func - res = update_dataset(get_http_api_auth, dataset_id, {"embedding_model": embedding_model}) - assert res["code"] == expected_code - if expected_code == 0: - res = list_datasets(get_http_api_auth, {"id": dataset_id}) - assert res["data"][0]["embedding_model"] == embedding_model - else: - assert res["message"] == expected_message + payload = {"name": name} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 0, res + res = list_datasets(get_http_api_auth) + assert res["code"] == 0, res + assert res["data"][0]["name"] == name, res + + @pytest.mark.p2 @pytest.mark.parametrize( - "chunk_method, expected_code, expected_message", + "name, expected_message", [ - ("naive", 0, ""), - ("manual", 0, ""), - ("qa", 0, ""), - ("table", 0, ""), - ("paper", 0, ""), - ("book", 0, ""), - ("laws", 0, ""), - ("presentation", 0, ""), - ("picture", 0, ""), - ("one", 0, ""), - ("email", 0, ""), - ("tag", 0, ""), - ("", 0, ""), - ( - "other_chunk_method", - 102, - "'other_chunk_method' is not in ['naive', 'manual', 'qa', 'table', 'paper', 'book', 'laws', 'presentation', 'picture', 'one', 'email', 'tag']", - ), + ("", "String should have at least 1 character"), + (" ", "String should have at least 1 character"), + ("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"), + (0, "Input should be a valid string"), + (None, "Input should be a valid string"), ], + ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], ) - def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method, expected_code, expected_message): + def test_name_invalid(self, get_http_api_auth, add_dataset_func, name, expected_message): dataset_id = add_dataset_func - res = update_dataset(get_http_api_auth, dataset_id, {"chunk_method": chunk_method}) - assert res["code"] == expected_code - if expected_code == 0: - res = list_datasets(get_http_api_auth, {"id": dataset_id}) - if chunk_method != "": - assert res["data"][0]["chunk_method"] == chunk_method - else: - assert res["data"][0]["chunk_method"] == "naive" - else: - assert res["message"] == expected_message + payload = {"name": name} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 101, res + assert expected_message in res["message"], res + @pytest.mark.p3 + def test_name_duplicated(self, get_http_api_auth, add_datasets_func): + dataset_ids = add_datasets_func[0] + name = "dataset_1" + payload = {"name": name} + res = update_dataset(get_http_api_auth, dataset_ids, payload) + assert res["code"] == 102, res + assert res["message"] == f"Dataset name '{name}' already exists", res + + @pytest.mark.p3 + def test_name_case_insensitive(self, get_http_api_auth, add_datasets_func): + dataset_id = add_datasets_func[0] + name = "DATASET_1" + payload = {"name": name} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 102, res + assert res["message"] == f"Dataset name '{name}' already exists", res + + @pytest.mark.p2 def test_avatar(self, get_http_api_auth, add_dataset_func, tmp_path): dataset_id = add_dataset_func fn = create_image_file(tmp_path / "ragflow_test.png") - payload = {"avatar": encode_avatar(fn)} + payload = { + "avatar": f"data:image/png;base64,{encode_avatar(fn)}", + } res = update_dataset(get_http_api_auth, dataset_id, payload) - assert res["code"] == 0 + assert res["code"] == 0, res + res = list_datasets(get_http_api_auth) + assert res["code"] == 0, res + assert res["data"][0]["avatar"] == f"data:image/png;base64,{encode_avatar(fn)}", res + + @pytest.mark.p2 + def test_avatar_exceeds_limit_length(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"avatar": "a" * 65536} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 101, res + assert "String should have at most 65535 characters" in res["message"], res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "name, avatar_prefix, expected_message", + [ + ("empty_prefix", "", "Missing MIME prefix. Expected format: data:;base64,"), + ("missing_comma", "data:image/png;base64", "Missing MIME prefix. Expected format: data:;base64,"), + ("unsupported_mine_type", "invalid_mine_prefix:image/png;base64,", "Invalid MIME prefix format. Must start with 'data:'"), + ("invalid_mine_type", "data:unsupported_mine_type;base64,", "Unsupported MIME type. Allowed: ['image/jpeg', 'image/png']"), + ], + ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"], + ) + def test_avatar_invalid_prefix(self, get_http_api_auth, add_dataset_func, tmp_path, name, avatar_prefix, expected_message): + dataset_id = add_dataset_func + fn = create_image_file(tmp_path / "ragflow_test.png") + payload = { + "name": name, + "avatar": f"{avatar_prefix}{encode_avatar(fn)}", + } + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 101, res + assert expected_message in res["message"], res + + @pytest.mark.p3 + def test_avatar_none(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"avatar": None} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(get_http_api_auth) + assert res["code"] == 0, res + assert res["data"][0]["avatar"] is None, res + + @pytest.mark.p2 def test_description(self, get_http_api_auth, add_dataset_func): dataset_id = add_dataset_func payload = {"description": "description"} @@ -153,95 +217,533 @@ class TestDatasetUpdate: assert res["code"] == 0 res = list_datasets(get_http_api_auth, {"id": dataset_id}) + assert res["code"] == 0, res assert res["data"][0]["description"] == "description" - def test_pagerank(self, get_http_api_auth, add_dataset_func): + @pytest.mark.p2 + def test_description_exceeds_limit_length(self, get_http_api_auth, add_dataset_func): dataset_id = add_dataset_func - payload = {"pagerank": 1} + payload = {"description": "a" * 65536} res = update_dataset(get_http_api_auth, dataset_id, payload) - assert res["code"] == 0 + assert res["code"] == 101, res + assert "String should have at most 65535 characters" in res["message"], res + + @pytest.mark.p3 + def test_description_none(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"description": None} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 0, res res = list_datasets(get_http_api_auth, {"id": dataset_id}) - assert res["data"][0]["pagerank"] == 1 - - def test_similarity_threshold(self, get_http_api_auth, add_dataset_func): - dataset_id = add_dataset_func - payload = {"similarity_threshold": 1} - res = update_dataset(get_http_api_auth, dataset_id, payload) - assert res["code"] == 0 - - res = list_datasets(get_http_api_auth, {"id": dataset_id}) - assert res["data"][0]["similarity_threshold"] == 1 + assert res["code"] == 0, res + assert res["data"][0]["description"] is None + @pytest.mark.p1 @pytest.mark.parametrize( - "permission, expected_code", + "embedding_model", [ - ("me", 0), - ("team", 0), - ("", 0), - ("ME", 102), - ("TEAM", 102), - ("other_permission", 102), + "BAAI/bge-large-zh-v1.5@BAAI", + "maidalun1020/bce-embedding-base_v1@Youdao", + "embedding-3@ZHIPU-AI", ], + ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"], ) - def test_permission(self, get_http_api_auth, add_dataset_func, permission, expected_code): + def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model): + dataset_id = add_dataset_func + payload = {"embedding_model": embedding_model} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(get_http_api_auth) + assert res["code"] == 0, res + assert res["data"][0]["embedding_model"] == embedding_model, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, embedding_model", + [ + ("unknown_llm_name", "unknown@ZHIPU-AI"), + ("unknown_llm_factory", "embedding-3@unknown"), + ("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"), + ("tenant_no_auth", "text-embedding-3-small@OpenAI"), + ], + ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"], + ) + def test_embedding_model_invalid(self, get_http_api_auth, add_dataset_func, name, embedding_model): + dataset_id = add_dataset_func + payload = {"name": name, "embedding_model": embedding_model} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 101, res + if "tenant_no_auth" in name: + assert res["message"] == f"Unauthorized model: <{embedding_model}>", res + else: + assert res["message"] == f"Unsupported model: <{embedding_model}>", res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, embedding_model", + [ + ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), + ("missing_model_name", "@BAAI"), + ("missing_provider", "BAAI/bge-large-zh-v1.5@"), + ("whitespace_only_model_name", " @BAAI"), + ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), + ], + ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], + ) + def test_embedding_model_format(self, get_http_api_auth, add_dataset_func, name, embedding_model): + dataset_id = add_dataset_func + payload = {"name": name, "embedding_model": embedding_model} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 101, res + if name == "missing_at": + assert "Embedding model identifier must follow @ format" in res["message"], res + else: + assert "Both model_name and provider must be non-empty strings" in res["message"], res + + @pytest.mark.p2 + def test_embedding_model_none(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"embedding_model": None} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 101, res + assert "Input should be a valid string" in res["message"], res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "name, permission", + [ + ("me", "me"), + ("team", "team"), + ("me_upercase", "ME"), + ("team_upercase", "TEAM"), + ], + ids=["me", "team", "me_upercase", "team_upercase"], + ) + def test_permission(self, get_http_api_auth, add_dataset_func, name, permission): + dataset_id = add_dataset_func + payload = {"name": name, "permission": permission} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(get_http_api_auth) + assert res["code"] == 0, res + assert res["data"][0]["permission"] == permission.lower(), res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "permission", + [ + "", + "unknown", + list(), + ], + ids=["empty", "unknown", "type_error"], + ) + def test_permission_invalid(self, get_http_api_auth, add_dataset_func, permission): dataset_id = add_dataset_func payload = {"permission": permission} res = update_dataset(get_http_api_auth, dataset_id, payload) - assert res["code"] == expected_code + assert res["code"] == 101 + assert "Input should be 'me' or 'team'" in res["message"] - res = list_datasets(get_http_api_auth, {"id": dataset_id}) - if expected_code == 0 and permission != "": - assert res["data"][0]["permission"] == permission - if permission == "": - assert res["data"][0]["permission"] == "me" - - def test_vector_similarity_weight(self, get_http_api_auth, add_dataset_func): + @pytest.mark.p3 + def test_permission_none(self, get_http_api_auth, add_dataset_func): dataset_id = add_dataset_func - payload = {"vector_similarity_weight": 1} + payload = {"name": "test_permission_none", "permission": None} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 101, res + assert "Input should be 'me' or 'team'" in res["message"], res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "chunk_method", + [ + "naive", + "book", + "email", + "laws", + "manual", + "one", + "paper", + "picture", + "presentation", + "qa", + "table", + "tag", + ], + ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"], + ) + def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method): + dataset_id = add_dataset_func + payload = {"chunk_method": chunk_method} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(get_http_api_auth) + assert res["code"] == 0, res + assert res["data"][0]["chunk_method"] == chunk_method, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "chunk_method", + [ + "", + "unknown", + list(), + ], + ids=["empty", "unknown", "type_error"], + ) + def test_chunk_method_invalid(self, get_http_api_auth, add_dataset_func, chunk_method): + dataset_id = add_dataset_func + payload = {"chunk_method": chunk_method} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 101, res + assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res + + @pytest.mark.p3 + def test_chunk_method_none(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"chunk_method": None} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 101, res + assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res + + @pytest.mark.p2 + @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) + def test_pagerank(self, get_http_api_auth, add_dataset_func, pagerank): + dataset_id = add_dataset_func + payload = {"pagerank": pagerank} res = update_dataset(get_http_api_auth, dataset_id, payload) assert res["code"] == 0 res = list_datasets(get_http_api_auth, {"id": dataset_id}) - assert res["data"][0]["vector_similarity_weight"] == 1 + assert res["code"] == 0, res + assert res["data"][0]["pagerank"] == pagerank - def test_invalid_dataset_id(self, get_http_api_auth): - res = update_dataset(get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"}) - assert res["code"] == 102 - assert res["message"] == "You don't own the dataset" + @pytest.mark.p2 + @pytest.mark.parametrize( + "pagerank, expected_message", + [ + (-1, "Input should be greater than or equal to 0"), + (101, "Input should be less than or equal to 100"), + ], + ids=["min_limit", "max_limit"], + ) + def test_pagerank_invalid(self, get_http_api_auth, add_dataset_func, pagerank, expected_message): + dataset_id = add_dataset_func + payload = {"pagerank": pagerank} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 101, res + assert expected_message in res["message"], res + @pytest.mark.p3 + def test_pagerank_none(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"pagerank": None} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 101, res + assert "Input should be a valid integer" in res["message"], res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "parser_config", + [ + {"auto_keywords": 0}, + {"auto_keywords": 16}, + {"auto_keywords": 32}, + {"auto_questions": 0}, + {"auto_questions": 5}, + {"auto_questions": 10}, + {"chunk_token_num": 1}, + {"chunk_token_num": 1024}, + {"chunk_token_num": 2048}, + {"delimiter": "\n"}, + {"delimiter": " "}, + {"html4excel": True}, + {"html4excel": False}, + {"layout_recognize": "DeepDOC"}, + {"layout_recognize": "Plain Text"}, + {"tag_kb_ids": ["1", "2"]}, + {"topn_tags": 1}, + {"topn_tags": 5}, + {"topn_tags": 10}, + {"filename_embd_weight": 0.1}, + {"filename_embd_weight": 0.5}, + {"filename_embd_weight": 1.0}, + {"task_page_size": 1}, + {"task_page_size": None}, + {"pages": [[1, 100]]}, + {"pages": None}, + {"graphrag": {"use_graphrag": True}}, + {"graphrag": {"use_graphrag": False}}, + {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}, + {"graphrag": {"method": "general"}}, + {"graphrag": {"method": "light"}}, + {"graphrag": {"community": True}}, + {"graphrag": {"community": False}}, + {"graphrag": {"resolution": True}}, + {"graphrag": {"resolution": False}}, + {"raptor": {"use_raptor": True}}, + {"raptor": {"use_raptor": False}}, + {"raptor": {"prompt": "Who are you?"}}, + {"raptor": {"max_token": 1}}, + {"raptor": {"max_token": 1024}}, + {"raptor": {"max_token": 2048}}, + {"raptor": {"threshold": 0.0}}, + {"raptor": {"threshold": 0.5}}, + {"raptor": {"threshold": 1.0}}, + {"raptor": {"max_cluster": 1}}, + {"raptor": {"max_cluster": 512}}, + {"raptor": {"max_cluster": 1024}}, + {"raptor": {"random_seed": 0}}, + ], + ids=[ + "auto_keywords_min", + "auto_keywords_mid", + "auto_keywords_max", + "auto_questions_min", + "auto_questions_mid", + "auto_questions_max", + "chunk_token_num_min", + "chunk_token_num_mid", + "chunk_token_num_max", + "delimiter", + "delimiter_space", + "html4excel_true", + "html4excel_false", + "layout_recognize_DeepDOC", + "layout_recognize_navie", + "tag_kb_ids", + "topn_tags_min", + "topn_tags_mid", + "topn_tags_max", + "filename_embd_weight_min", + "filename_embd_weight_mid", + "filename_embd_weight_max", + "task_page_size_min", + "task_page_size_None", + "pages", + "pages_none", + "graphrag_true", + "graphrag_false", + "graphrag_entity_types", + "graphrag_method_general", + "graphrag_method_light", + "graphrag_community_true", + "graphrag_community_false", + "graphrag_resolution_true", + "graphrag_resolution_false", + "raptor_true", + "raptor_false", + "raptor_prompt", + "raptor_max_token_min", + "raptor_max_token_mid", + "raptor_max_token_max", + "raptor_threshold_min", + "raptor_threshold_mid", + "raptor_threshold_max", + "raptor_max_cluster_min", + "raptor_max_cluster_mid", + "raptor_max_cluster_max", + "raptor_random_seed_min", + ], + ) + def test_parser_config(self, get_http_api_auth, add_dataset_func, parser_config): + dataset_id = add_dataset_func + payload = {"parser_config": parser_config} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(get_http_api_auth) + assert res["code"] == 0, res + for k, v in parser_config.items(): + if isinstance(v, dict): + for kk, vv in v.items(): + assert res["data"][0]["parser_config"][k][kk] == vv, res + else: + assert res["data"][0]["parser_config"][k] == v, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "parser_config, expected_message", + [ + ({"auto_keywords": -1}, "Input should be greater than or equal to 0"), + ({"auto_keywords": 33}, "Input should be less than or equal to 32"), + ({"auto_keywords": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"auto_keywords": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"auto_questions": -1}, "Input should be greater than or equal to 0"), + ({"auto_questions": 11}, "Input should be less than or equal to 10"), + ({"auto_questions": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"auto_questions": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"chunk_token_num": 0}, "Input should be greater than or equal to 1"), + ({"chunk_token_num": 2049}, "Input should be less than or equal to 2048"), + ({"chunk_token_num": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"chunk_token_num": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"delimiter": ""}, "String should have at least 1 character"), + ({"html4excel": "string"}, "Input should be a valid boolean, unable to interpret input"), + ({"tag_kb_ids": "1,2"}, "Input should be a valid list"), + ({"tag_kb_ids": [1, 2]}, "Input should be a valid string"), + ({"topn_tags": 0}, "Input should be greater than or equal to 1"), + ({"topn_tags": 11}, "Input should be less than or equal to 10"), + ({"topn_tags": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"topn_tags": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"filename_embd_weight": -1}, "Input should be greater than or equal to 0"), + ({"filename_embd_weight": 1.1}, "Input should be less than or equal to 1"), + ({"filename_embd_weight": "string"}, "Input should be a valid number, unable to parse string as a number"), + ({"task_page_size": 0}, "Input should be greater than or equal to 1"), + ({"task_page_size": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"task_page_size": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"pages": "1,2"}, "Input should be a valid list"), + ({"pages": ["1,2"]}, "Input should be a valid list"), + ({"pages": [["string1", "string2"]]}, "Input should be a valid integer, unable to parse string as an integer"), + ({"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ({"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"), + ({"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"), + ({"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"), + ({"graphrag": {"method": None}}, "Input should be 'light' or 'general'"), + ({"graphrag": {"community": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ({"graphrag": {"resolution": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ({"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ({"raptor": {"prompt": ""}}, "String should have at least 1 character"), + ({"raptor": {"prompt": " "}}, "String should have at least 1 character"), + ({"raptor": {"max_token": 0}}, "Input should be greater than or equal to 1"), + ({"raptor": {"max_token": 2049}}, "Input should be less than or equal to 2048"), + ({"raptor": {"max_token": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), + ({"raptor": {"max_token": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ({"raptor": {"threshold": -0.1}}, "Input should be greater than or equal to 0"), + ({"raptor": {"threshold": 1.1}}, "Input should be less than or equal to 1"), + ({"raptor": {"threshold": "string"}}, "Input should be a valid number, unable to parse string as a number"), + ({"raptor": {"max_cluster": 0}}, "Input should be greater than or equal to 1"), + ({"raptor": {"max_cluster": 1025}}, "Input should be less than or equal to 1024"), + ({"raptor": {"max_cluster": 3.14}}, "Input should be a valid integer, got a number with a fractional par"), + ({"raptor": {"max_cluster": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ({"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"), + ({"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), + ({"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ({"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"), + ], + ids=[ + "auto_keywords_min_limit", + "auto_keywords_max_limit", + "auto_keywords_float_not_allowed", + "auto_keywords_type_invalid", + "auto_questions_min_limit", + "auto_questions_max_limit", + "auto_questions_float_not_allowed", + "auto_questions_type_invalid", + "chunk_token_num_min_limit", + "chunk_token_num_max_limit", + "chunk_token_num_float_not_allowed", + "chunk_token_num_type_invalid", + "delimiter_empty", + "html4excel_type_invalid", + "tag_kb_ids_not_list", + "tag_kb_ids_int_in_list", + "topn_tags_min_limit", + "topn_tags_max_limit", + "topn_tags_float_not_allowed", + "topn_tags_type_invalid", + "filename_embd_weight_min_limit", + "filename_embd_weight_max_limit", + "filename_embd_weight_type_invalid", + "task_page_size_min_limit", + "task_page_size_float_not_allowed", + "task_page_size_type_invalid", + "pages_not_list", + "pages_not_list_in_list", + "pages_not_int_list", + "graphrag_type_invalid", + "graphrag_entity_types_not_list", + "graphrag_entity_types_not_str_in_list", + "graphrag_method_unknown", + "graphrag_method_none", + "graphrag_community_type_invalid", + "graphrag_resolution_type_invalid", + "raptor_type_invalid", + "raptor_prompt_empty", + "raptor_prompt_space", + "raptor_max_token_min_limit", + "raptor_max_token_max_limit", + "raptor_max_token_float_not_allowed", + "raptor_max_token_type_invalid", + "raptor_threshold_min_limit", + "raptor_threshold_max_limit", + "raptor_threshold_type_invalid", + "raptor_max_cluster_min_limit", + "raptor_max_cluster_max_limit", + "raptor_max_cluster_float_not_allowed", + "raptor_max_cluster_type_invalid", + "raptor_random_seed_min_limit", + "raptor_random_seed_float_not_allowed", + "raptor_random_seed_type_invalid", + "parser_config_type_invalid", + ], + ) + def test_parser_config_invalid(self, get_http_api_auth, add_dataset_func, parser_config, expected_message): + dataset_id = add_dataset_func + payload = {"parser_config": parser_config} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 101, res + assert expected_message in res["message"], res + + @pytest.mark.p2 + def test_parser_config_empty(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"parser_config": {}} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(get_http_api_auth) + assert res["code"] == 0, res + assert res["data"][0]["parser_config"] == {} + + # @pytest.mark.p2 + # def test_parser_config_unset(self, get_http_api_auth, add_dataset_func): + # dataset_id = add_dataset_func + # payload = {"name": "default_unset"} + # res = update_dataset(get_http_api_auth, dataset_id, payload) + # assert res["code"] == 0, res + + # res = list_datasets(get_http_api_auth) + # assert res["code"] == 0, res + # assert res["data"][0]["parser_config"] == { + # "chunk_token_num": 128, + # "delimiter": r"\n", + # "html4excel": False, + # "layout_recognize": "DeepDOC", + # "raptor": {"use_raptor": False}, + # }, res + + @pytest.mark.p3 + def test_parser_config_none(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"parser_config": None} + res = update_dataset(get_http_api_auth, dataset_id, payload) + assert res["code"] == 101, res + assert "Input should be a valid dictionary or instance of ParserConfig" in res["message"], res + + @pytest.mark.p2 @pytest.mark.parametrize( "payload", [ - {"chunk_count": 1}, + {"id": "id"}, + {"tenant_id": "e57c1966f99211efb41e9e45646e0111"}, + {"created_by": "created_by"}, {"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, {"create_time": 1741671443322}, - {"created_by": "aa"}, - {"document_count": 1}, - {"id": "id"}, - {"status": "1"}, - {"tenant_id": "e57c1966f99211efb41e9e45646e0111"}, - {"token_num": 1}, {"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, {"update_time": 1741671443339}, + {"document_count": 1}, + {"chunk_count": 1}, + {"token_num": 1}, + {"status": "1"}, + {"unknown_field": "unknown_field"}, ], ) - def test_modify_read_only_field(self, get_http_api_auth, add_dataset_func, payload): + def test_unsupported_field(self, get_http_api_auth, add_dataset_func, payload): dataset_id = add_dataset_func res = update_dataset(get_http_api_auth, dataset_id, payload) - assert res["code"] == 101 - assert "is readonly" in res["message"] - - def test_modify_unknown_field(self, get_http_api_auth, add_dataset_func): - dataset_id = add_dataset_func - res = update_dataset(get_http_api_auth, dataset_id, {"unknown_field": 0}) - assert res["code"] == 100 - - @pytest.mark.p3 - def test_concurrent_update(self, get_http_api_auth, add_dataset_func): - dataset_id = add_dataset_func - - with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)] - responses = [f.result() for f in futures] - assert all(r["code"] == 0 for r in responses) + assert res["code"] == 101, res + assert "Extra inputs are not permitted" in res["message"], res diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py index adde4c67a..57177590a 100644 --- a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py +++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py @@ -14,6 +14,7 @@ # limitations under the License. # from concurrent.futures import ThreadPoolExecutor +from time import sleep import pytest from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets, stop_parse_documnets @@ -173,6 +174,7 @@ def test_stop_parse_100_files(get_http_api_auth, add_dataset_func, tmp_path): dataset_id = add_dataset_func document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + sleep(1) res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) assert res["code"] == 0 validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)