refa: Optimize create dataset validation (#7451)

### What problem does this PR solve?

Optimize dataset validation and add function docs

### Type of change

- [x] Refactoring
This commit is contained in:
liu an 2025-05-06 17:38:06 +08:00 committed by GitHub
parent 2f768b96e8
commit c98933499a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 333 additions and 93 deletions

View File

@ -19,7 +19,6 @@ import logging
from flask import request from flask import request
from peewee import OperationalError from peewee import OperationalError
from pydantic import ValidationError
from api import settings from api import settings
from api.db import FileSource, StatusEnum from api.db import FileSource, StatusEnum
@ -41,8 +40,9 @@ from api.utils.api_utils import (
token_required, token_required,
valid, valid,
valid_parser_config, valid_parser_config,
verify_embedding_availability,
) )
from api.utils.validation_utils import CreateDatasetReq, format_validation_error_message from api.utils.validation_utils import CreateDatasetReq, validate_and_parse_json_request
@manager.route("/datasets", methods=["POST"]) # noqa: F821 @manager.route("/datasets", methods=["POST"]) # noqa: F821
@ -107,21 +107,14 @@ def create(tenant_id):
data: data:
type: object type: object
""" """
req_i = request.json
if not isinstance(req_i, dict):
return get_error_argument_result(f"Invalid request payload: expected object, got {type(req_i).__name__}")
try:
req_v = CreateDatasetReq(**req_i)
except ValidationError as e:
return get_error_argument_result(format_validation_error_message(e))
# Field name transformations during model dump: # Field name transformations during model dump:
# | Original | Dump Output | # | Original | Dump Output |
# |----------------|-------------| # |----------------|-------------|
# | embedding_model| embd_id | # | embedding_model| embd_id |
# | chunk_method | parser_id | # | chunk_method | parser_id |
req = req_v.model_dump(by_alias=True) req, err = validate_and_parse_json_request(request, CreateDatasetReq)
if err is not None:
return get_error_argument_result(err)
try: try:
if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
@ -146,21 +139,9 @@ def create(tenant_id):
if not req.get("embd_id"): if not req.get("embd_id"):
req["embd_id"] = t.embd_id req["embd_id"] = t.embd_id
else: else:
builtin_embedding_models = [ ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
"BAAI/bge-large-zh-v1.5@BAAI", if not ok:
"maidalun1020/bce-embedding-base_v1@Youdao", return err
]
is_builtin_model = req["embd_id"] in builtin_embedding_models
try:
# model name must be model_name@model_factory
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["embd_id"])
is_tenant_model = TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="embedding")
is_supported_model = LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding")
if not (is_supported_model and (is_builtin_model or is_tenant_model)):
return get_error_argument_result(f"The embedding_model '{req['embd_id']}' is not supported")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
try: try:
if not KnowledgebaseService.save(**req): if not KnowledgebaseService.save(**req):

View File

@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import json
import os import os
from datetime import date from datetime import date
from enum import IntEnum, Enum from enum import Enum, IntEnum
import json
import rag.utils
import rag.utils.es_conn import rag.utils.es_conn
import rag.utils.infinity_conn import rag.utils.infinity_conn
import rag.utils.opensearch_coon import rag.utils.opensearch_coon
import rag.utils
from rag.nlp import search
from graphrag import search as kg_search
from api.utils import get_base_config, decrypt_database_config
from api.constants import RAG_FLOW_SERVICE_NAME from api.constants import RAG_FLOW_SERVICE_NAME
from api.utils import decrypt_database_config, get_base_config
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from graphrag import search as kg_search
from rag.nlp import search
LIGHTEN = int(os.environ.get('LIGHTEN', "0")) LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
LLM = None LLM = None
LLM_FACTORY = None LLM_FACTORY = None
@ -45,7 +45,7 @@ HOST_PORT = None
SECRET_KEY = None SECRET_KEY = None
FACTORY_LLM_INFOS = None FACTORY_LLM_INFOS = None
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
DATABASE = decrypt_database_config(name=DATABASE_TYPE) DATABASE = decrypt_database_config(name=DATABASE_TYPE)
# authentication # authentication
@ -66,11 +66,13 @@ kg_retrievaler = None
# user registration switch # user registration switch
REGISTER_ENABLED = 1 REGISTER_ENABLED = 1
BUILTIN_EMBEDDING_MODELS = ["BAAI/bge-large-zh-v1.5@BAAI", "maidalun1020/bce-embedding-base_v1@Youdao"]
def init_settings(): def init_settings():
global LLM, LLM_FACTORY, LLM_BASE_URL, LIGHTEN, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED global LLM, LLM_FACTORY, LLM_BASE_URL, LIGHTEN, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED
LIGHTEN = int(os.environ.get('LIGHTEN', "0")) LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
DATABASE = decrypt_database_config(name=DATABASE_TYPE) DATABASE = decrypt_database_config(name=DATABASE_TYPE)
LLM = get_base_config("user_default_llm", {}) LLM = get_base_config("user_default_llm", {})
LLM_DEFAULT_MODELS = LLM.get("default_models", {}) LLM_DEFAULT_MODELS = LLM.get("default_models", {})
@ -79,8 +81,8 @@ def init_settings():
try: try:
REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1")) REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))
except Exception: except Exception:
pass pass
try: try:
with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f: with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f:
FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"] FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"]
@ -89,7 +91,7 @@ def init_settings():
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
if not LIGHTEN: if not LIGHTEN:
EMBEDDING_MDL = "BAAI/bge-large-zh-v1.5@BAAI" EMBEDDING_MDL = BUILTIN_EMBEDDING_MODELS[0]
if LLM_DEFAULT_MODELS: if LLM_DEFAULT_MODELS:
CHAT_MDL = LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL) CHAT_MDL = LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL)
@ -103,30 +105,25 @@ def init_settings():
EMBEDDING_MDL = EMBEDDING_MDL + (f"@{LLM_FACTORY}" if "@" not in EMBEDDING_MDL and EMBEDDING_MDL != "" else "") EMBEDDING_MDL = EMBEDDING_MDL + (f"@{LLM_FACTORY}" if "@" not in EMBEDDING_MDL and EMBEDDING_MDL != "" else "")
RERANK_MDL = RERANK_MDL + (f"@{LLM_FACTORY}" if "@" not in RERANK_MDL and RERANK_MDL != "" else "") RERANK_MDL = RERANK_MDL + (f"@{LLM_FACTORY}" if "@" not in RERANK_MDL and RERANK_MDL != "" else "")
ASR_MDL = ASR_MDL + (f"@{LLM_FACTORY}" if "@" not in ASR_MDL and ASR_MDL != "" else "") ASR_MDL = ASR_MDL + (f"@{LLM_FACTORY}" if "@" not in ASR_MDL and ASR_MDL != "" else "")
IMAGE2TEXT_MDL = IMAGE2TEXT_MDL + ( IMAGE2TEXT_MDL = IMAGE2TEXT_MDL + (f"@{LLM_FACTORY}" if "@" not in IMAGE2TEXT_MDL and IMAGE2TEXT_MDL != "" else "")
f"@{LLM_FACTORY}" if "@" not in IMAGE2TEXT_MDL and IMAGE2TEXT_MDL != "" else "")
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
API_KEY = LLM.get("api_key", "") API_KEY = LLM.get("api_key")
PARSERS = LLM.get( PARSERS = LLM.get(
"parsers", "parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag") )
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
SECRET_KEY = get_base_config( SECRET_KEY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key", str(date.today()))
RAG_FLOW_SERVICE_NAME,
{}).get("secret_key", str(date.today()))
global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH, OAUTH_CONFIG global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH, OAUTH_CONFIG
# authentication # authentication
AUTHENTICATION_CONF = get_base_config("authentication", {}) AUTHENTICATION_CONF = get_base_config("authentication", {})
# client # client
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get( CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False)
"client", {}).get(
"switch", False)
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key") HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
GITHUB_OAUTH = get_base_config("oauth", {}).get("github") GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu") FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
@ -134,7 +131,7 @@ def init_settings():
OAUTH_CONFIG = get_base_config("oauth", {}) OAUTH_CONFIG = get_base_config("oauth", {})
global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler
DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch") DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
# DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch") # DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
lower_case_doc_engine = DOC_ENGINE.lower() lower_case_doc_engine = DOC_ENGINE.lower()
if lower_case_doc_engine == "elasticsearch": if lower_case_doc_engine == "elasticsearch":

View File

@ -36,11 +36,13 @@ from flask import (
request as flask_request, request as flask_request,
) )
from itsdangerous import URLSafeTimedSerializer from itsdangerous import URLSafeTimedSerializer
from peewee import OperationalError
from werkzeug.http import HTTP_STATUS_CODES from werkzeug.http import HTTP_STATUS_CODES
from api import settings from api import settings
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
from api.db.db_models import APIToken from api.db.db_models import APIToken
from api.db.services.llm_service import LLMService, TenantLLMService
from api.utils import CustomJSONEncoder, get_uuid, json_dumps from api.utils import CustomJSONEncoder, get_uuid, json_dumps
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
@ -464,3 +466,55 @@ def check_duplicate_ids(ids, id_type="item"):
# Return unique IDs and error messages # Return unique IDs and error messages
return list(set(ids)), duplicate_messages return list(set(ids)), duplicate_messages
def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]:
    """Check whether an embedding model may be used by the given tenant.

    The checks run in order and stop at the first failure:
        1. Split ``embd_id`` into a model name and a factory
        2. System support: ``LLMService`` must know the model as an embedding model
        3. Authorization: the model must be listed in
           ``settings.BUILTIN_EMBEDDING_MODELS`` or be one of the tenant's own
           embedding models

    A database ``OperationalError`` raised by any of these steps is logged and
    converted into a failure result; it does not propagate to the caller.

    Args:
        embd_id (str): Embedding model identifier in "model_name@factory" format
        tenant_id (str): Tenant whose configured models are consulted

    Returns:
        tuple[bool, Response | None]:
            - ``(True, None)`` when the model is supported and authorized
            - ``(False, error)`` otherwise, where ``error`` is the value built by
              ``get_error_argument_result`` (unsupported or unauthorized model) or
              ``get_error_data_result`` (database failure), ready to be returned
              from the request handler

    Note:
        NOTE(review): exceptions other than ``OperationalError`` (for example from
        parsing a malformed identifier) are not handled here - confirm what
        ``split_model_name_and_factory`` does with input lacking "@".
    """
    try:
        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(embd_id)
        if not LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"):
            return False, get_error_argument_result(f"Unsupported model: <{embd_id}>")

        # Built-in models need no tenant-level entry, so the tenant lookup is skipped for them.
        if embd_id in settings.BUILTIN_EMBEDDING_MODELS:
            return True, None

        # Tongyi-Qianwen is added to TenantLLM by default, but remains unusable with empty api_key
        tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id)
        is_tenant_model = any(
            llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding"
            for llm in tenant_llms
        )
        if not is_tenant_model:
            return False, get_error_argument_result(f"Unauthorized model: <{embd_id}>")
    except OperationalError as e:
        logging.exception(e)
        return False, get_error_data_result(message="Database operation failed")

    return True, None

View File

@ -14,13 +14,102 @@
# limitations under the License. # limitations under the License.
# #
from enum import auto from enum import auto
from typing import Annotated, List, Optional from typing import Annotated, Any
from flask import Request
from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator
from strenum import StrEnum from strenum import StrEnum
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
from api.constants import DATASET_NAME_LIMIT
def validate_and_parse_json_request(request: Request, validator: type[BaseModel]) -> tuple[dict[str, Any] | None, str | None]:
    """Validate a JSON request body against a Pydantic model and return it as a dict.

    The request goes through four stages; the first one that fails ends the
    call with an error message:
        1. Content-Type verification (must be application/json)
        2. JSON syntax validation
        3. Payload structure type checking (the body must be a JSON object)
        4. Pydantic model validation with error formatting

    Args:
        request (Request): Flask request object containing the HTTP payload
        validator (type[BaseModel]): Pydantic model class the payload is validated against

    Returns:
        tuple[dict[str, Any] | None, str | None]:
            - First element:
                - Validated payload dumped with ``by_alias=True`` on success
                - None on validation failure
            - Second element:
                - None on success
                - Diagnostic error message on failure

    Note:
        ``UnsupportedMediaType``, ``BadRequest`` and ``ValidationError`` are caught
        here and reported through the returned message; they are not raised to
        the caller.

    Examples:
        Invalid Content-Type:
        ```python
        # Returns: (None, "Unsupported content type: Expected application/json, got text/xml")
        ```

        Malformed JSON:
        ```python
        # Returns: (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding")
        ```

        Non-object payload such as ``[]``:
        ```python
        # Returns: (None, "Invalid request payload: expected object, got list")
        ```
    """
    try:
        payload = request.get_json()
    except UnsupportedMediaType:
        return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
    except BadRequest:
        return None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding"

    # Only a JSON ``null`` body falls back to an empty object. Other falsy values
    # ([], "", 0, false) must reach the type check below instead of being
    # silently replaced by {}.
    if payload is None:
        payload = {}

    if not isinstance(payload, dict):
        return None, f"Invalid request payload: expected object, got {type(payload).__name__}"

    try:
        validated_request = validator(**payload)
    except ValidationError as e:
        return None, format_validation_error_message(e)

    # by_alias=True makes serialization aliases the keys of the returned dict.
    return validated_request.model_dump(by_alias=True), None
def format_validation_error_message(e: ValidationError) -> str: def format_validation_error_message(e: ValidationError) -> str:
"""Formats validation errors into a standardized string format.
Processes pydantic ValidationError objects to create human-readable error messages
containing field locations, error descriptions, and input values.
Args:
e (ValidationError): The validation error instance containing error details
Returns:
str: Formatted error messages joined by newlines. Each line contains:
- Field path (dot-separated)
- Error message
- Truncated input value (max 128 chars)
Example:
>>> try:
... UserModel(name=123, email="invalid")
... except ValidationError as e:
... print(format_validation_error_message(e))
Field: <name> - Message: <Input should be a valid string> - Value: <123>
Field: <email> - Message: <value is not a valid email address> - Value: <invalid>
"""
error_messages = [] error_messages = []
for error in e.errors(): for error in e.errors():
@ -86,7 +175,7 @@ class RaptorConfig(Base):
class GraphragConfig(Base): class GraphragConfig(Base):
use_graphrag: bool = Field(default=False) use_graphrag: bool = Field(default=False)
entity_types: List[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"]) entity_types: list[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])
method: GraphragMethodEnum = Field(default=GraphragMethodEnum.light) method: GraphragMethodEnum = Field(default=GraphragMethodEnum.light)
community: bool = Field(default=False) community: bool = Field(default=False)
resolution: bool = Field(default=False) resolution: bool = Field(default=False)
@ -97,30 +186,59 @@ class ParserConfig(Base):
auto_questions: int = Field(default=0, ge=0, le=10) auto_questions: int = Field(default=0, ge=0, le=10)
chunk_token_num: int = Field(default=128, ge=1, le=2048) chunk_token_num: int = Field(default=128, ge=1, le=2048)
delimiter: str = Field(default=r"\n", min_length=1) delimiter: str = Field(default=r"\n", min_length=1)
graphrag: Optional[GraphragConfig] = None graphrag: GraphragConfig | None = None
html4excel: bool = False html4excel: bool = False
layout_recognize: str = "DeepDOC" layout_recognize: str = "DeepDOC"
raptor: Optional[RaptorConfig] = None raptor: RaptorConfig | None = None
tag_kb_ids: List[str] = Field(default_factory=list) tag_kb_ids: list[str] = Field(default_factory=list)
topn_tags: int = Field(default=1, ge=1, le=10) topn_tags: int = Field(default=1, ge=1, le=10)
filename_embd_weight: Optional[float] = Field(default=None, ge=0.0, le=1.0) filename_embd_weight: float | None = Field(default=None, ge=0.0, le=1.0)
task_page_size: Optional[int] = Field(default=None, ge=1) task_page_size: int | None = Field(default=None, ge=1)
pages: Optional[List[List[int]]] = None pages: list[list[int]] | None = None
class CreateDatasetReq(Base): class CreateDatasetReq(Base):
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=128), Field(...)] name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
avatar: Optional[str] = Field(default=None, max_length=65535) avatar: str | None = Field(default=None, max_length=65535)
description: Optional[str] = Field(default=None, max_length=65535) description: str | None = Field(default=None, max_length=65535)
embedding_model: Annotated[Optional[str], StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")] embedding_model: Annotated[str | None, StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")]
permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)] permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)]
chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")] chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")]
pagerank: int = Field(default=0, ge=0, le=100) pagerank: int = Field(default=0, ge=0, le=100)
parser_config: Optional[ParserConfig] = Field(default=None) parser_config: ParserConfig | None = Field(default=None)
@field_validator("avatar") @field_validator("avatar")
@classmethod @classmethod
def validate_avatar_base64(cls, v: str) -> str: def validate_avatar_base64(cls, v: str) -> str:
"""Validates Base64-encoded avatar string format and MIME type compliance.
Implements a three-stage validation workflow:
1. MIME prefix existence check
2. MIME type format validation
3. Supported type verification
Args:
v (str): Raw avatar field value
Returns:
str: Validated Base64 string
Raises:
ValueError: For structural errors in these cases:
- Missing MIME prefix header
- Invalid MIME prefix format
- Unsupported image MIME type
Example:
```python
# Valid case
CreateDatasetReq(avatar="data:image/png;base64,iVBORw0KGg...")
# Invalid cases
CreateDatasetReq(avatar="image/jpeg;base64,...") # Missing 'data:' prefix
CreateDatasetReq(avatar="data:video/mp4;base64,...") # Unsupported MIME type
```
"""
if v is None: if v is None:
return v return v
@ -141,22 +259,83 @@ class CreateDatasetReq(Base):
@field_validator("embedding_model", mode="after") @field_validator("embedding_model", mode="after")
@classmethod @classmethod
def validate_embedding_model(cls, v: str) -> str: def validate_embedding_model(cls, v: str) -> str:
"""Validates embedding model identifier format compliance.
Validation pipeline:
1. Structural format verification
2. Component non-empty check
3. Value normalization
Args:
v (str): Raw model identifier
Returns:
str: Validated <model_name>@<provider> format
Raises:
ValueError: For these violations:
- Missing @ separator
- Empty model_name/provider
- Invalid component structure
Examples:
Valid: "text-embedding-3-large@openai"
Invalid: "invalid_model" (no @)
Invalid: "@openai" (empty model_name)
Invalid: "text-embedding-3-large@" (empty provider)
"""
if "@" not in v: if "@" not in v:
raise ValueError("Embedding model must be xxx@yyy") raise ValueError("Embedding model identifier must follow <model_name>@<provider> format")
components = v.split("@", 1)
if len(components) != 2 or not all(components):
raise ValueError("Both model_name and provider must be non-empty strings")
model_name, provider = components
if not model_name.strip() or not provider.strip():
raise ValueError("Model name and provider cannot be whitespace-only strings")
return v return v
@field_validator("permission", mode="before") @field_validator("permission", mode="before")
@classmethod @classmethod
def permission_auto_lowercase(cls, v: str) -> str: def permission_auto_lowercase(cls, v: str) -> str:
if isinstance(v, str): """Normalize permission input to lowercase for consistent PermissionEnum matching.
return v.lower()
return v Args:
v (str): Raw input value for the permission field
Returns:
Lowercase string if input is string type, otherwise returns original value
Behavior:
- Converts string inputs to lowercase (e.g., "ME" → "me")
- Non-string values pass through unchanged
- Works in validation pre-processing stage (before enum conversion)
"""
return v.lower() if isinstance(v, str) else v
@field_validator("parser_config", mode="after") @field_validator("parser_config", mode="after")
@classmethod @classmethod
def validate_parser_config_json_length(cls, v: Optional[ParserConfig]) -> Optional[ParserConfig]: def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserConfig | None:
if v is not None: """Validates serialized JSON length constraints for parser configuration.
json_str = v.model_dump_json()
if len(json_str) > 65535: Implements a three-stage validation workflow:
raise ValueError("Parser config have at most 65535 characters") 1. Null check - bypass validation for empty configurations
2. Model serialization - convert Pydantic model to JSON string
3. Size verification - enforce maximum allowed payload size
Args:
v (ParserConfig | None): Raw parser configuration object
Returns:
ParserConfig | None: Validated configuration object
Raises:
ValueError: When serialized JSON exceeds 65,535 characters
"""
if v is None:
return v
if (json_str := v.model_dump_json()) and len(json_str) > 65535:
raise ValueError(f"Parser config exceeds size limit (max 65,535 characters). Current size: {len(json_str):,}")
return v return v

View File

@ -39,23 +39,23 @@ SESSION_WITH_CHAT_NAME_LIMIT = 255
# DATASET MANAGEMENT # DATASET MANAGEMENT
def create_dataset(auth, payload=None): def create_dataset(auth, payload=None, headers=HEADERS, data=None):
res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload) res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
return res.json() return res.json()
def list_datasets(auth, params=None): def list_datasets(auth, params=None, headers=HEADERS):
res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, params=params) res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params)
return res.json() return res.json()
def update_dataset(auth, dataset_id, payload=None): def update_dataset(auth, dataset_id, payload=None, headers=HEADERS):
res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=HEADERS, auth=auth, json=payload) res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload)
return res.json() return res.json()
def delete_datasets(auth, payload=None): def delete_datasets(auth, payload=None, headers=HEADERS):
res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload) res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload)
return res.json() return res.json()

View File

@ -98,6 +98,25 @@ class TestDatasetCreation:
assert res["code"] == 101, res assert res["code"] == 101, res
assert res["message"] == f"Dataset name '{name.lower()}' already exists", res assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
def test_bad_content_type(self, get_http_api_auth):
    """A request sent with a non-JSON Content-Type is rejected with code 101."""
    content_type = "text/xml"
    headers = {"Content-Type": content_type}
    expected = f"Unsupported content type: Expected application/json, got {content_type}"

    res = create_dataset(get_http_api_auth, {"name": "name"}, headers)

    assert res["code"] == 101, res
    assert res["message"] == expected, res
@pytest.mark.parametrize(
    "payload, expected_message",
    [
        ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"),
        # '"a"' is valid JSON but decodes to a string, not an object.
        ('"a"', "Invalid request payload: expected object, got str"),
    ],
    ids=["malformed_json_syntax", "invalid_request_payload_type"],
)
def test_bad_payload(self, get_http_api_auth, payload, expected_message):
    """Raw bodies that are not a JSON object are rejected with code 101.

    The payload is sent through ``data`` so the body reaches the server exactly
    as written instead of being JSON-encoded by the client helper.
    """
    res = create_dataset(get_http_api_auth, data=payload)
    assert res["code"] == 101, res
    assert expected_message in res["message"], res
def test_avatar(self, get_http_api_auth, tmp_path): def test_avatar(self, get_http_api_auth, tmp_path):
fn = create_image_file(tmp_path / "ragflow_test.png") fn = create_image_file(tmp_path / "ragflow_test.png")
payload = { payload = {
@ -158,7 +177,7 @@ class TestDatasetCreation:
("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"), ("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"),
("embedding_model_default", None), ("embedding_model_default", None),
], ],
ids=["builtin_baai", "builtin_youdao", "tenant__zhipu", "default"], ids=["builtin_baai", "builtin_youdao", "tenant_zhipu", "default"],
) )
def test_valid_embedding_model(self, get_http_api_auth, name, embedding_model): def test_valid_embedding_model(self, get_http_api_auth, name, embedding_model):
if embedding_model is None: if embedding_model is None:
@ -178,29 +197,39 @@ class TestDatasetCreation:
[ [
("unknown_llm_name", "unknown@ZHIPU-AI"), ("unknown_llm_name", "unknown@ZHIPU-AI"),
("unknown_llm_factory", "embedding-3@unknown"), ("unknown_llm_factory", "embedding-3@unknown"),
("tenant_no_auth", "deepseek-chat@DeepSeek"), ("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"),
("tenant_no_auth", "text-embedding-3-small@OpenAI"),
], ],
ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth"], ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"],
) )
def test_invalid_embedding_model(self, get_http_api_auth, name, embedding_model): def test_invalid_embedding_model(self, get_http_api_auth, name, embedding_model):
payload = {"name": name, "embedding_model": embedding_model} payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(get_http_api_auth, payload) res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res assert res["code"] == 101, res
assert res["message"] == f"The embedding_model '{embedding_model}' is not supported", res if "tenant_no_auth" in name:
assert res["message"] == f"Unauthorized model: <{embedding_model}>", res
else:
assert res["message"] == f"Unsupported model: <{embedding_model}>", res
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name, embedding_model", "name, embedding_model",
[ [
("builtin_missing_at", "BAAI/bge-large-zh-v1.5"), ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
("tenant_missing_at", "embedding-3ZHIPU-AI"), ("missing_model_name", "@BAAI"),
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
("whitespace_only_model_name", " @BAAI"),
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
], ],
ids=["builtin_missing_at", "tenant_missing_at"], ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
) )
def test_embedding_model_missing_at(self, get_http_api_auth, name, embedding_model): def test_embedding_model_format(self, get_http_api_auth, name, embedding_model):
payload = {"name": name, "embedding_model": embedding_model} payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(get_http_api_auth, payload) res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res assert res["code"] == 101, res
assert "Embedding model must be xxx@yyy" in res["message"], res if name == "missing_at":
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
else:
assert "Both model_name and provider must be non-empty strings" in res["message"], res
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name, permission", "name, permission",
@ -485,7 +514,7 @@ class TestDatasetCreation:
("raptor_random_seed_min_limit", {"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"), ("raptor_random_seed_min_limit", {"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"),
("raptor_random_seed_float_not_allowed", {"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), ("raptor_random_seed_float_not_allowed", {"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
("raptor_random_seed_type_invalid", {"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), ("raptor_random_seed_type_invalid", {"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
("parser_config_type_invalid", {"delimiter": "a" * 65536}, "Parser config have at most 65535 characters"), ("parser_config_type_invalid", {"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"),
], ],
ids=[ ids=[
"auto_keywords_min_limit", "auto_keywords_min_limit",