diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py
index c721743c4..b52f584d3 100644
--- a/api/apps/sdk/dataset.py
+++ b/api/apps/sdk/dataset.py
@@ -19,7 +19,6 @@ import logging
 
 from flask import request
 from peewee import OperationalError
-from pydantic import ValidationError
 
 from api import settings
 from api.db import FileSource, StatusEnum
@@ -41,8 +40,9 @@ from api.utils.api_utils import (
     token_required,
     valid,
     valid_parser_config,
+    verify_embedding_availability,
 )
-from api.utils.validation_utils import CreateDatasetReq, format_validation_error_message
+from api.utils.validation_utils import CreateDatasetReq, validate_and_parse_json_request
 
 
 @manager.route("/datasets", methods=["POST"])  # noqa: F821
@@ -107,21 +107,14 @@ def create(tenant_id):
                     data:
                         type: object
     """
-    req_i = request.json
-    if not isinstance(req_i, dict):
-        return get_error_argument_result(f"Invalid request payload: expected object, got {type(req_i).__name__}")
-
-    try:
-        req_v = CreateDatasetReq(**req_i)
-    except ValidationError as e:
-        return get_error_argument_result(format_validation_error_message(e))
-
     # Field name transformations during model dump:
     # | Original        | Dump Output |
     # |-----------------|-------------|
     # | embedding_model | embd_id     |
     # | chunk_method    | parser_id   |
-    req = req_v.model_dump(by_alias=True)
+    req, err = validate_and_parse_json_request(request, CreateDatasetReq)
+    if err is not None:
+        return get_error_argument_result(err)
 
     try:
         if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
@@ -146,21 +139,9 @@ def create(tenant_id):
             if not req.get("embd_id"):
                 req["embd_id"] = t.embd_id
             else:
-                builtin_embedding_models = [
-                    "BAAI/bge-large-zh-v1.5@BAAI",
-                    "maidalun1020/bce-embedding-base_v1@Youdao",
-                ]
-                is_builtin_model = req["embd_id"] in builtin_embedding_models
-                try:
-                    # model name must be model_name@model_factory
-                    llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["embd_id"])
-                    is_tenant_model = TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="embedding")
-                    is_supported_model = LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding")
-                    if not (is_supported_model and (is_builtin_model or is_tenant_model)):
-                        return get_error_argument_result(f"The embedding_model '{req['embd_id']}' is not supported")
-                except OperationalError as e:
-                    logging.exception(e)
-                    return get_error_data_result(message="Database operation failed")
+                ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
+                if not ok:
+                    return err
 
     try:
         if not KnowledgebaseService.save(**req):
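A minimal sketch (not part of the diff) of the alias mapping described by the comment table above — `model_dump(by_alias=True)` renames the public API fields to their storage column names; the asserted values assume the schema defaults declared in validation_utils.py:

```python
from api.utils.validation_utils import CreateDatasetReq

req = CreateDatasetReq(name="kb1", embedding_model="BAAI/bge-large-zh-v1.5@BAAI")
dumped = req.model_dump(by_alias=True)
assert dumped["embd_id"] == "BAAI/bge-large-zh-v1.5@BAAI"  # embedding_model -> embd_id
assert dumped["parser_id"] == "naive"  # chunk_method -> parser_id (StrEnum default)
```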
diff --git a/api/settings.py b/api/settings.py
index f92e96bae..c1507dc3b 100644
--- a/api/settings.py
+++ b/api/settings.py
@@ -13,22 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
 import os
 from datetime import date
-from enum import IntEnum, Enum
-import json
+from enum import Enum, IntEnum
+
+import rag.utils
 import rag.utils.es_conn
 import rag.utils.infinity_conn
 import rag.utils.opensearch_coon
-
-import rag.utils
-from rag.nlp import search
-from graphrag import search as kg_search
-from api.utils import get_base_config, decrypt_database_config
 from api.constants import RAG_FLOW_SERVICE_NAME
+from api.utils import decrypt_database_config, get_base_config
 from api.utils.file_utils import get_project_base_directory
+from graphrag import search as kg_search
+from rag.nlp import search
 
-LIGHTEN = int(os.environ.get('LIGHTEN', "0"))
+LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
 
 LLM = None
 LLM_FACTORY = None
@@ -45,7 +45,7 @@ HOST_PORT = None
 SECRET_KEY = None
 
 FACTORY_LLM_INFOS = None
-DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
+DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
 DATABASE = decrypt_database_config(name=DATABASE_TYPE)
 
 # authentication
@@ -66,11 +66,13 @@ kg_retrievaler = None
 # user registration switch
 REGISTER_ENABLED = 1
 
+BUILTIN_EMBEDDING_MODELS = ["BAAI/bge-large-zh-v1.5@BAAI", "maidalun1020/bce-embedding-base_v1@Youdao"]
+
 
 def init_settings():
     global LLM, LLM_FACTORY, LLM_BASE_URL, LIGHTEN, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED
-    LIGHTEN = int(os.environ.get('LIGHTEN', "0"))
-    DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
+    LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
+    DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
     DATABASE = decrypt_database_config(name=DATABASE_TYPE)
     LLM = get_base_config("user_default_llm", {})
     LLM_DEFAULT_MODELS = LLM.get("default_models", {})
@@ -79,8 +81,8 @@ def init_settings():
     try:
         REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))
     except Exception:
-        pass
-
+        pass
+
     try:
         with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f:
             FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"]
@@ -89,7 +91,7 @@ def init_settings():
 
     global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
     if not LIGHTEN:
-        EMBEDDING_MDL = "BAAI/bge-large-zh-v1.5@BAAI"
+        EMBEDDING_MDL = BUILTIN_EMBEDDING_MODELS[0]
 
     if LLM_DEFAULT_MODELS:
         CHAT_MDL = LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL)
@@ -103,30 +105,25 @@ def init_settings():
         EMBEDDING_MDL = EMBEDDING_MDL + (f"@{LLM_FACTORY}" if "@" not in EMBEDDING_MDL and EMBEDDING_MDL != "" else "")
         RERANK_MDL = RERANK_MDL + (f"@{LLM_FACTORY}" if "@" not in RERANK_MDL and RERANK_MDL != "" else "")
         ASR_MDL = ASR_MDL + (f"@{LLM_FACTORY}" if "@" not in ASR_MDL and ASR_MDL != "" else "")
-        IMAGE2TEXT_MDL = IMAGE2TEXT_MDL + (
-            f"@{LLM_FACTORY}" if "@" not in IMAGE2TEXT_MDL and IMAGE2TEXT_MDL != "" else "")
+        IMAGE2TEXT_MDL = IMAGE2TEXT_MDL + (f"@{LLM_FACTORY}" if "@" not in IMAGE2TEXT_MDL and IMAGE2TEXT_MDL != "" else "")
 
     global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
-    API_KEY = LLM.get("api_key", "")
+    API_KEY = LLM.get("api_key")
     PARSERS = LLM.get(
-        "parsers",
-        "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag")
+        "parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"
+    )
     HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
    HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
-    SECRET_KEY = get_base_config(
-        RAG_FLOW_SERVICE_NAME,
-        {}).get("secret_key", str(date.today()))
+    SECRET_KEY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key", str(date.today()))
 
     global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH, OAUTH_CONFIG
     # authentication
     AUTHENTICATION_CONF = get_base_config("authentication", {})
     # client
-    CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
-        "client", {}).get(
-        "switch", False)
+    CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False)
     HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
     GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
     FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
@@ -134,7 +131,7 @@ def init_settings():
     OAUTH_CONFIG = get_base_config("oauth", {})
 
     global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler
-    DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
+    DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
     # DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
     lower_case_doc_engine = DOC_ENGINE.lower()
     if lower_case_doc_engine == "elasticsearch":
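For context, a hedged illustration (not part of the diff) of the default-model normalization a few lines up: bare names from `user_default_llm` get the configured factory appended, while names already carrying `@factory` — and empty strings — pass through unchanged:

```python
def normalize(model: str, factory: str) -> str:
    # Mirrors: MDL + (f"@{LLM_FACTORY}" if "@" not in MDL and MDL != "" else "")
    return model + (f"@{factory}" if "@" not in model and model != "" else "")

assert normalize("text-embedding-v3", "Tongyi-Qianwen") == "text-embedding-v3@Tongyi-Qianwen"
assert normalize("BAAI/bge-large-zh-v1.5@BAAI", "Tongyi-Qianwen") == "BAAI/bge-large-zh-v1.5@BAAI"
assert normalize("", "Tongyi-Qianwen") == ""
```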
diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index e17667e4f..a9b9b76ae 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -36,11 +36,13 @@ from flask import (
     request as flask_request,
 )
 from itsdangerous import URLSafeTimedSerializer
+from peewee import OperationalError
 from werkzeug.http import HTTP_STATUS_CODES
 
 from api import settings
 from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
 from api.db.db_models import APIToken
+from api.db.services.llm_service import LLMService, TenantLLMService
 from api.utils import CustomJSONEncoder, get_uuid, json_dumps
 
 requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
@@ -464,3 +466,55 @@ def check_duplicate_ids(ids, id_type="item"):
 
     # Return unique IDs and error messages
     return list(set(ids)), duplicate_messages
+
+
+def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]:
+    """Verifies availability of an embedding model for a specific tenant.
+
+    Implements a four-stage validation process:
+    1. Model identifier parsing and validation
+    2. System support verification
+    3. Tenant authorization check
+    4. Database operation error handling
+
+    Args:
+        embd_id (str): Unique identifier for the embedding model in format "model_name@factory"
+        tenant_id (str): Tenant identifier for access control
+
+    Returns:
+        tuple[bool, Response | None]:
+        - First element (bool):
+            - True: Model is available and authorized
+            - False: Validation failed
+        - Second element contains:
+            - None on success
+            - Error Response with diagnostic payload on failure
+
+    Raises:
+        ValueError: When model identifier format is invalid
+        OperationalError: When database connection fails (caught and converted to an error response)
+
+    Examples:
+        >>> verify_embedding_availability("text-embedding@openai", "tenant_123")
+        (True, None)
+
+        >>> verify_embedding_availability("invalid_model", "tenant_123")
+        (False, {'code': 101, 'message': "Unsupported model: <invalid_model>"})
+    """
+    try:
+        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(embd_id)
+        if not LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"):
+            return False, get_error_argument_result(f"Unsupported model: <{embd_id}>")
+
+        # Tongyi-Qianwen is added to TenantLLM by default, but remains unusable with an empty api_key
+        tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id)
+        is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms)
+
+        is_builtin_model = embd_id in settings.BUILTIN_EMBEDDING_MODELS
+        if not (is_builtin_model or is_tenant_model):
+            return False, get_error_argument_result(f"Unauthorized model: <{embd_id}>")
+    except OperationalError as e:
+        logging.exception(e)
+        return False, get_error_data_result(message="Database operation failed")
+
+    return True, None
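How a caller consumes the pair — a usage sketch mirroring the create() route above, not additional diff content:

```python
ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
if not ok:
    return err  # err is already a fully-formed Flask error response
# otherwise proceed to persist the dataset with the verified embd_id
```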
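Usage sketch (not part of the diff): any endpoint can collapse its payload handling to the same two lines the dataset route uses; `CreateDatasetReq` stands in for whatever validator class the endpoint declares:

```python
req, err = validate_and_parse_json_request(request, CreateDatasetReq)
if err is not None:
    return get_error_argument_result(err)
# req now holds DB-ready keys such as req["embd_id"] and req["parser_id"]
```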
diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py
index dea0ddb87..d16e80149 100644
--- a/api/utils/validation_utils.py
+++ b/api/utils/validation_utils.py
@@ -14,13 +14,102 @@
 # limitations under the License.
 #
 from enum import auto
-from typing import Annotated, List, Optional
+from typing import Annotated, Any
 
+from flask import Request
 from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator
 from strenum import StrEnum
+from werkzeug.exceptions import BadRequest, UnsupportedMediaType
+
+from api.constants import DATASET_NAME_LIMIT
+
+
+def validate_and_parse_json_request(request: Request, validator: type[BaseModel]) -> tuple[dict[str, Any] | None, str | None]:
+    """Validates and parses JSON requests through a multi-stage validation pipeline.
+
+    Implements a robust four-stage validation process:
+    1. Content-Type verification (must be application/json)
+    2. JSON syntax validation
+    3. Payload structure type checking
+    4. Pydantic model validation with error formatting
+
+    Args:
+        request (Request): Flask request object containing HTTP payload
+
+    Returns:
+        tuple[dict[str, Any] | None, str | None]:
+        - First element:
+            - Validated dictionary on success
+            - None on validation failure
+        - Second element:
+            - None on success
+            - Diagnostic error message on failure
+
+    Handles (caught internally and converted to error returns):
+        UnsupportedMediaType: When Content-Type ≠ application/json
+        BadRequest: For structural JSON syntax errors
+        ValidationError: When payload violates Pydantic schema rules
+
+    Examples:
+        Successful validation:
+        ```python
+        # Input: {"name": "Dataset1", "format": "csv"}
+        # Returns: ({"name": "Dataset1", "format": "csv"}, None)
+        ```
+
+        Invalid Content-Type:
+        ```python
+        # Returns: (None, "Unsupported content type: Expected application/json, got text/xml")
+        ```
+
+        Malformed JSON:
+        ```python
+        # Returns: (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding")
+        ```
+    """
+    try:
+        payload = request.get_json() or {}
+    except UnsupportedMediaType:
+        return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
+    except BadRequest:
+        return None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding"
+
+    if not isinstance(payload, dict):
+        return None, f"Invalid request payload: expected object, got {type(payload).__name__}"
+
+    try:
+        validated_request = validator(**payload)
+    except ValidationError as e:
+        return None, format_validation_error_message(e)
+
+    parsed_payload = validated_request.model_dump(by_alias=True)
+
+    return parsed_payload, None
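A hedged example (not in the diff) of the formatter's output against the dataset schema; the exact message text comes from pydantic and the name limit from DATASET_NAME_LIMIT, both assumed here:

```python
from pydantic import ValidationError

from api.utils.validation_utils import CreateDatasetReq, format_validation_error_message

try:
    CreateDatasetReq(name="x" * 200)  # exceeds DATASET_NAME_LIMIT (assumed to be 128)
except ValidationError as e:
    print(format_validation_error_message(e))
    # e.g. Field: <name> - Message: <String should have at most 128 characters> - Value: <xxx...>
```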
+
+
 def format_validation_error_message(e: ValidationError) -> str:
+    """Formats validation errors into a standardized string format.
+
+    Processes pydantic ValidationError objects to create human-readable error messages
+    containing field locations, error descriptions, and input values.
+
+    Args:
+        e (ValidationError): The validation error instance containing error details
+
+    Returns:
+        str: Formatted error messages joined by newlines. Each line contains:
+            - Field path (dot-separated)
+            - Error message
+            - Truncated input value (max 128 chars)
+
+    Example:
+        >>> try:
+        ...     UserModel(name=123, email="invalid")
+        ... except ValidationError as e:
+        ...     print(format_validation_error_message(e))
+        Field: <name> - Message: <Input should be a valid string> - Value: <123>
+        Field: <email> - Message: <value is not a valid email address> - Value: <invalid>
+    """
     error_messages = []
 
     for error in e.errors():
@@ -86,7 +175,7 @@ class RaptorConfig(Base):
 
 class GraphragConfig(Base):
     use_graphrag: bool = Field(default=False)
-    entity_types: List[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])
+    entity_types: list[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])
     method: GraphragMethodEnum = Field(default=GraphragMethodEnum.light)
     community: bool = Field(default=False)
     resolution: bool = Field(default=False)
@@ -97,30 +186,59 @@ class ParserConfig(Base):
     auto_questions: int = Field(default=0, ge=0, le=10)
     chunk_token_num: int = Field(default=128, ge=1, le=2048)
     delimiter: str = Field(default=r"\n", min_length=1)
-    graphrag: Optional[GraphragConfig] = None
+    graphrag: GraphragConfig | None = None
     html4excel: bool = False
     layout_recognize: str = "DeepDOC"
-    raptor: Optional[RaptorConfig] = None
-    tag_kb_ids: List[str] = Field(default_factory=list)
+    raptor: RaptorConfig | None = None
+    tag_kb_ids: list[str] = Field(default_factory=list)
     topn_tags: int = Field(default=1, ge=1, le=10)
-    filename_embd_weight: Optional[float] = Field(default=None, ge=0.0, le=1.0)
-    task_page_size: Optional[int] = Field(default=None, ge=1)
-    pages: Optional[List[List[int]]] = None
+    filename_embd_weight: float | None = Field(default=None, ge=0.0, le=1.0)
+    task_page_size: int | None = Field(default=None, ge=1)
+    pages: list[list[int]] | None = None
 
 
 class CreateDatasetReq(Base):
-    name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=128), Field(...)]
-    avatar: Optional[str] = Field(default=None, max_length=65535)
-    description: Optional[str] = Field(default=None, max_length=65535)
-    embedding_model: Annotated[Optional[str], StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")]
+    name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
+    avatar: str | None = Field(default=None, max_length=65535)
+    description: str | None = Field(default=None, max_length=65535)
+    embedding_model: Annotated[str | None, StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")]
     permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)]
     chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")]
     pagerank: int = Field(default=0, ge=0, le=100)
-    parser_config: Optional[ParserConfig] = Field(default=None)
+    parser_config: ParserConfig | None = Field(default=None)
 
     @field_validator("avatar")
     @classmethod
     def validate_avatar_base64(cls, v: str) -> str:
+        """Validates Base64-encoded avatar string format and MIME type compliance.
+
+        Implements a three-stage validation workflow:
+        1. MIME prefix existence check
+        2. MIME type format validation
+        3. Supported type verification
+
+        Args:
+            v (str): Raw avatar field value
+
+        Returns:
+            str: Validated Base64 string
+
+        Raises:
+            ValueError: For structural errors in these cases:
+                - Missing MIME prefix header
+                - Invalid MIME prefix format
+                - Unsupported image MIME type
+
+        Example:
+            ```python
+            # Valid case
+            CreateDatasetReq(avatar="data:image/png;base64,iVBORw0KGg...")
+
+            # Invalid cases
+            CreateDatasetReq(avatar="image/jpeg;base64,...")  # Missing 'data:' prefix
+            CreateDatasetReq(avatar="data:video/mp4;base64,...")  # Unsupported MIME type
+            ```
+        """
         if v is None:
             return v
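The validator body past the `None` check sits outside this hunk; a minimal sketch of the three documented stages might look like the following — the regex, the supported-type set, and the helper name are all assumptions, not the diff's code:

```python
import re

SUPPORTED_IMAGE_MIME = {"image/png", "image/jpeg", "image/gif", "image/webp"}  # assumed set

def check_avatar_prefix(v: str) -> str:
    # Stage 1: the data-URL prefix must be present at all
    if not v.startswith("data:"):
        raise ValueError("Missing MIME prefix. Expected format: data:<mime>;base64,<data>")
    # Stage 2: the prefix must parse as data:<mime>;base64,
    match = re.match(r"^data:([^;]+);base64,", v)
    if match is None:
        raise ValueError("Invalid MIME prefix format")
    # Stage 3: only image MIME types are accepted
    if match.group(1) not in SUPPORTED_IMAGE_MIME:
        raise ValueError(f"Unsupported MIME type: {match.group(1)}")
    return v
```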
@@ -141,22 +259,83 @@ class CreateDatasetReq(Base):
 
     @field_validator("embedding_model", mode="after")
     @classmethod
     def validate_embedding_model(cls, v: str) -> str:
+        """Validates embedding model identifier format compliance.
+
+        Validation pipeline:
+        1. Structural format verification
+        2. Component non-empty check
+        3. Value normalization
+
+        Args:
+            v (str): Raw model identifier
+
+        Returns:
+            str: Validated <model_name>@<provider> format
+
+        Raises:
+            ValueError: For these violations:
+                - Missing @ separator
+                - Empty model_name/provider
+                - Invalid component structure
+
+        Examples:
+            Valid: "text-embedding-3-large@openai"
+            Invalid: "invalid_model" (no @)
+            Invalid: "@openai" (empty model_name)
+            Invalid: "text-embedding-3-large@" (empty provider)
+        """
         if "@" not in v:
-            raise ValueError("Embedding model must be xxx@yyy")
+            raise ValueError("Embedding model identifier must follow <model_name>@<provider> format")
+
+        components = v.split("@", 1)
+        if len(components) != 2 or not all(components):
+            raise ValueError("Both model_name and provider must be non-empty strings")
+
+        model_name, provider = components
+        if not model_name.strip() or not provider.strip():
+            raise ValueError("Model name and provider cannot be whitespace-only strings")
         return v
 
     @field_validator("permission", mode="before")
     @classmethod
     def permission_auto_lowercase(cls, v: str) -> str:
-        if isinstance(v, str):
-            return v.lower()
-        return v
+        """Normalize permission input to lowercase for consistent PermissionEnum matching.
+
+        Args:
+            v (str): Raw input value for the permission field
+
+        Returns:
+            Lowercase string if input is string type, otherwise the original value
+
+        Behavior:
+        - Converts string inputs to lowercase (e.g., "ME" → "me")
+        - Non-string values pass through unchanged
+        - Runs in the validation pre-processing stage (before enum conversion)
+        """
+        return v.lower() if isinstance(v, str) else v
 
     @field_validator("parser_config", mode="after")
     @classmethod
-    def validate_parser_config_json_length(cls, v: Optional[ParserConfig]) -> Optional[ParserConfig]:
-        if v is not None:
-            json_str = v.model_dump_json()
-            if len(json_str) > 65535:
-                raise ValueError("Parser config have at most 65535 characters")
+    def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserConfig | None:
+        """Validates serialized JSON length constraints for parser configuration.
+
+        Implements a three-stage validation workflow:
+        1. Null check - bypass validation for empty configurations
+        2. Model serialization - convert Pydantic model to JSON string
+        3. Size verification - enforce maximum allowed payload size
+
+        Args:
+            v (ParserConfig | None): Raw parser configuration object
+
+        Returns:
+            ParserConfig | None: Validated configuration object
+
+        Raises:
+            ValueError: When serialized JSON exceeds 65,535 characters
+        """
+        if v is None:
+            return v
+
+        if (json_str := v.model_dump_json()) and len(json_str) > 65535:
+            raise ValueError(f"Parser config exceeds size limit (max 65,535 characters). Current size: {len(json_str):,}")
         return v
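A note on the split (illustration, not diff content): `split("@", 1)` divides at the first `@` only, so everything after it — including any further `@` characters — belongs to the provider:

```python
model_name, provider = "BAAI/bge-large-zh-v1.5@BAAI".split("@", 1)
assert (model_name, provider) == ("BAAI/bge-large-zh-v1.5", "BAAI")
```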
diff --git a/sdk/python/test/test_http_api/common.py b/sdk/python/test/test_http_api/common.py
index 424624f28..770f716e4 100644
--- a/sdk/python/test/test_http_api/common.py
+++ b/sdk/python/test/test_http_api/common.py
@@ -39,23 +39,23 @@ SESSION_WITH_CHAT_NAME_LIMIT = 255
 
 
 # DATASET MANAGEMENT
-def create_dataset(auth, payload=None):
-    res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload)
+def create_dataset(auth, payload=None, headers=HEADERS, data=None):
+    res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
     return res.json()
 
 
-def list_datasets(auth, params=None):
-    res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, params=params)
+def list_datasets(auth, params=None, headers=HEADERS):
+    res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params)
     return res.json()
 
 
-def update_dataset(auth, dataset_id, payload=None):
-    res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=HEADERS, auth=auth, json=payload)
+def update_dataset(auth, dataset_id, payload=None, headers=HEADERS):
+    res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload)
     return res.json()
 
 
-def delete_datasets(auth, payload=None):
-    res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload)
+def delete_datasets(auth, payload=None, headers=HEADERS):
+    res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload)
     return res.json()
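Usage sketch (not in the diff) of the widened helper signature: `data=` sends a raw, non-JSON body, and a positional `headers` dict overrides the default Content-Type — exactly what the new negative tests below rely on (`auth` is assumed to be a valid fixture value):

```python
res = create_dataset(auth, data="a")  # raw body -> exercises the malformed-JSON branch
res = create_dataset(auth, {"name": "n"}, {"Content-Type": "text/xml"})  # wrong Content-Type
```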
diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py b/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py
index 42db8ca99..25e175a74 100644
--- a/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py
+++ b/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py
@@ -98,6 +98,25 @@ class TestDatasetCreation:
         assert res["code"] == 101, res
         assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
 
+    def test_bad_content_type(self, get_http_api_auth):
+        BAD_CONTENT_TYPE = "text/xml"
+        res = create_dataset(get_http_api_auth, {"name": "name"}, {"Content-Type": BAD_CONTENT_TYPE})
+        assert res["code"] == 101, res
+        assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res
+
+    @pytest.mark.parametrize(
+        "payload, expected_message",
+        [
+            ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"),
+            ('"a"', "Invalid request payload: expected object"),
+        ],
+        ids=["malformed_json_syntax", "invalid_request_payload_type"],
+    )
+    def test_bad_payload(self, get_http_api_auth, payload, expected_message):
+        res = create_dataset(get_http_api_auth, data=payload)
+        assert res["code"] == 101, res
+        assert expected_message in res["message"], res
+
     def test_avatar(self, get_http_api_auth, tmp_path):
         fn = create_image_file(tmp_path / "ragflow_test.png")
         payload = {
@@ -158,7 +177,7 @@ class TestDatasetCreation:
             ("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"),
             ("embedding_model_default", None),
         ],
-        ids=["builtin_baai", "builtin_youdao", "tenant__zhipu", "default"],
+        ids=["builtin_baai", "builtin_youdao", "tenant_zhipu", "default"],
     )
     def test_valid_embedding_model(self, get_http_api_auth, name, embedding_model):
         if embedding_model is None:
@@ -178,29 +197,39 @@ class TestDatasetCreation:
         [
             ("unknown_llm_name", "unknown@ZHIPU-AI"),
             ("unknown_llm_factory", "embedding-3@unknown"),
-            ("tenant_no_auth", "deepseek-chat@DeepSeek"),
+            ("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"),
+            ("tenant_no_auth", "text-embedding-3-small@OpenAI"),
         ],
-        ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth"],
+        ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"],
     )
     def test_invalid_embedding_model(self, get_http_api_auth, name, embedding_model):
         payload = {"name": name, "embedding_model": embedding_model}
         res = create_dataset(get_http_api_auth, payload)
         assert res["code"] == 101, res
-        assert res["message"] == f"The embedding_model '{embedding_model}' is not supported", res
+        if "tenant_no_auth" in name:
+            assert res["message"] == f"Unauthorized model: <{embedding_model}>", res
+        else:
+            assert res["message"] == f"Unsupported model: <{embedding_model}>", res
 
     @pytest.mark.parametrize(
         "name, embedding_model",
         [
-            ("builtin_missing_at", "BAAI/bge-large-zh-v1.5"),
-            ("tenant_missing_at", "embedding-3ZHIPU-AI"),
+            ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
+            ("missing_model_name", "@BAAI"),
+            ("missing_provider", "BAAI/bge-large-zh-v1.5@"),
+            ("whitespace_only_model_name", " @BAAI"),
+            ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
         ],
-        ids=["builtin_missing_at", "tenant_missing_at"],
+        ids=["missing_at", "missing_model_name", "missing_provider", "whitespace_only_model_name", "whitespace_only_provider"],
     )
-    def test_embedding_model_missing_at(self, get_http_api_auth, name, embedding_model):
+    def test_embedding_model_format(self, get_http_api_auth, name, embedding_model):
         payload = {"name": name, "embedding_model": embedding_model}
         res = create_dataset(get_http_api_auth, payload)
         assert res["code"] == 101, res
-        assert "Embedding model must be xxx@yyy" in res["message"], res
+        if name == "missing_at":
+            assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
+        else:
+            assert "Both model_name and provider must be non-empty strings" in res["message"], res
 
     @pytest.mark.parametrize(
         "name, permission",
@@ -485,7 +514,7 @@ class TestDatasetCreation:
             ("raptor_random_seed_min_limit", {"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"),
             ("raptor_random_seed_float_not_allowed", {"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
             ("raptor_random_seed_type_invalid", {"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
-            ("parser_config_type_invalid", {"delimiter": "a" * 65536}, "Parser config have at most 65535 characters"),
+            ("parser_config_type_invalid", {"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"),
         ],
         ids=[
            "auto_keywords_min_limit",