refa: Optimize create dataset validation (#7451)
### What problem does this PR solve?

Optimize dataset validation and add function docs.

### Type of change

- [x] Refactoring
This commit is contained in: parent 2f768b96e8, commit c98933499a
```diff
@@ -19,7 +19,6 @@ import logging
 from flask import request
 from peewee import OperationalError
-from pydantic import ValidationError
 
 from api import settings
 from api.db import FileSource, StatusEnum
@@ -41,8 +40,9 @@ from api.utils.api_utils import (
     token_required,
     valid,
     valid_parser_config,
+    verify_embedding_availability,
 )
-from api.utils.validation_utils import CreateDatasetReq, format_validation_error_message
+from api.utils.validation_utils import CreateDatasetReq, validate_and_parse_json_request
 
 
 @manager.route("/datasets", methods=["POST"])  # noqa: F821
@@ -107,21 +107,14 @@ def create(tenant_id):
             data:
                 type: object
     """
-    req_i = request.json
-    if not isinstance(req_i, dict):
-        return get_error_argument_result(f"Invalid request payload: expected object, got {type(req_i).__name__}")
-
-    try:
-        req_v = CreateDatasetReq(**req_i)
-    except ValidationError as e:
-        return get_error_argument_result(format_validation_error_message(e))
-
     # Field name transformations during model dump:
     # | Original       | Dump Output |
     # |----------------|-------------|
     # | embedding_model| embd_id     |
     # | chunk_method   | parser_id   |
-    req = req_v.model_dump(by_alias=True)
+    req, err = validate_and_parse_json_request(request, CreateDatasetReq)
+    if err is not None:
+        return get_error_argument_result(err)
 
     try:
         if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
@@ -146,21 +139,9 @@ def create(tenant_id):
         if not req.get("embd_id"):
             req["embd_id"] = t.embd_id
         else:
-            builtin_embedding_models = [
-                "BAAI/bge-large-zh-v1.5@BAAI",
-                "maidalun1020/bce-embedding-base_v1@Youdao",
-            ]
-            is_builtin_model = req["embd_id"] in builtin_embedding_models
-            try:
-                # model name must be model_name@model_factory
-                llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["embd_id"])
-                is_tenant_model = TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="embedding")
-                is_supported_model = LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding")
-                if not (is_supported_model and (is_builtin_model or is_tenant_model)):
-                    return get_error_argument_result(f"The embedding_model '{req['embd_id']}' is not supported")
-            except OperationalError as e:
-                logging.exception(e)
-                return get_error_data_result(message="Database operation failed")
+            ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
+            if not ok:
+                return err
 
         try:
             if not KnowledgebaseService.save(**req):
```
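Taken together, the endpoint change collapses the old `request.json` type check, the explicit `try/except ValidationError`, and the manual `model_dump` into one helper call that returns a `(parsed, error)` pair. A minimal sketch of the resulting control flow, with the route body and the `code: 101` error shape as illustrative stand-ins rather than the full endpoint:

```python
from flask import Flask, request

from api.utils.validation_utils import CreateDatasetReq, validate_and_parse_json_request

app = Flask(__name__)


@app.route("/datasets", methods=["POST"])
def create_sketch():
    # One call covers content-type, JSON-syntax, payload-type, and schema checks.
    req, err = validate_and_parse_json_request(request, CreateDatasetReq)
    if err is not None:
        return {"code": 101, "message": err}
    # req is already dumped by alias: embedding_model -> embd_id,
    # chunk_method -> parser_id.
    return {"code": 0, "data": req}
```

The remaining hunks move the builtin model list into settings, add `verify_embedding_availability` to the shared API utilities, add `validate_and_parse_json_request` and docstrings to the validation utilities, and extend the HTTP API tests.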
```diff
@@ -13,22 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
 import os
 from datetime import date
-from enum import IntEnum, Enum
-import json
+from enum import Enum, IntEnum
+import rag.utils
 import rag.utils.es_conn
 import rag.utils.infinity_conn
 import rag.utils.opensearch_coon
 
-import rag.utils
-from rag.nlp import search
-from graphrag import search as kg_search
-from api.utils import get_base_config, decrypt_database_config
 from api.constants import RAG_FLOW_SERVICE_NAME
+from api.utils import decrypt_database_config, get_base_config
 from api.utils.file_utils import get_project_base_directory
+from graphrag import search as kg_search
+from rag.nlp import search
 
-LIGHTEN = int(os.environ.get('LIGHTEN', "0"))
+LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
 
 LLM = None
 LLM_FACTORY = None
@@ -45,7 +45,7 @@ HOST_PORT = None
 SECRET_KEY = None
 FACTORY_LLM_INFOS = None
 
-DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
+DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
 DATABASE = decrypt_database_config(name=DATABASE_TYPE)
 
 # authentication
@@ -66,11 +66,13 @@ kg_retrievaler = None
 # user registration switch
 REGISTER_ENABLED = 1
 
+BUILTIN_EMBEDDING_MODELS = ["BAAI/bge-large-zh-v1.5@BAAI", "maidalun1020/bce-embedding-base_v1@Youdao"]
+
 
 def init_settings():
     global LLM, LLM_FACTORY, LLM_BASE_URL, LIGHTEN, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED
-    LIGHTEN = int(os.environ.get('LIGHTEN', "0"))
-    DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
+    LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
+    DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
     DATABASE = decrypt_database_config(name=DATABASE_TYPE)
     LLM = get_base_config("user_default_llm", {})
     LLM_DEFAULT_MODELS = LLM.get("default_models", {})
@@ -89,7 +91,7 @@ def init_settings():
 
     global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
     if not LIGHTEN:
-        EMBEDDING_MDL = "BAAI/bge-large-zh-v1.5@BAAI"
+        EMBEDDING_MDL = BUILTIN_EMBEDDING_MODELS[0]
 
     if LLM_DEFAULT_MODELS:
         CHAT_MDL = LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL)
@@ -103,30 +105,25 @@ def init_settings():
         EMBEDDING_MDL = EMBEDDING_MDL + (f"@{LLM_FACTORY}" if "@" not in EMBEDDING_MDL and EMBEDDING_MDL != "" else "")
         RERANK_MDL = RERANK_MDL + (f"@{LLM_FACTORY}" if "@" not in RERANK_MDL and RERANK_MDL != "" else "")
         ASR_MDL = ASR_MDL + (f"@{LLM_FACTORY}" if "@" not in ASR_MDL and ASR_MDL != "" else "")
-        IMAGE2TEXT_MDL = IMAGE2TEXT_MDL + (
-            f"@{LLM_FACTORY}" if "@" not in IMAGE2TEXT_MDL and IMAGE2TEXT_MDL != "" else "")
+        IMAGE2TEXT_MDL = IMAGE2TEXT_MDL + (f"@{LLM_FACTORY}" if "@" not in IMAGE2TEXT_MDL and IMAGE2TEXT_MDL != "" else "")
 
     global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
-    API_KEY = LLM.get("api_key", "")
+    API_KEY = LLM.get("api_key")
     PARSERS = LLM.get(
-        "parsers",
-        "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag")
+        "parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"
+    )
 
     HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
    HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
 
-    SECRET_KEY = get_base_config(
-        RAG_FLOW_SERVICE_NAME,
-        {}).get("secret_key", str(date.today()))
+    SECRET_KEY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key", str(date.today()))
 
     global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH, OAUTH_CONFIG
     # authentication
     AUTHENTICATION_CONF = get_base_config("authentication", {})
 
     # client
-    CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
-        "client", {}).get(
-        "switch", False)
+    CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False)
     HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
     GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
     FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
@@ -134,7 +131,7 @@ def init_settings():
     OAUTH_CONFIG = get_base_config("oauth", {})
 
     global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler
-    DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
+    DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
     # DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
     lower_case_doc_engine = DOC_ENGINE.lower()
     if lower_case_doc_engine == "elasticsearch":
```
```diff
@@ -36,11 +36,13 @@ from flask import (
     request as flask_request,
 )
 from itsdangerous import URLSafeTimedSerializer
+from peewee import OperationalError
 from werkzeug.http import HTTP_STATUS_CODES
 
 from api import settings
 from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
 from api.db.db_models import APIToken
+from api.db.services.llm_service import LLMService, TenantLLMService
 from api.utils import CustomJSONEncoder, get_uuid, json_dumps
 
 requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
@@ -464,3 +466,55 @@ def check_duplicate_ids(ids, id_type="item"):
 
     # Return unique IDs and error messages
     return list(set(ids)), duplicate_messages
+
+
+def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]:
+    """Verifies availability of an embedding model for a specific tenant.
+
+    Implements a four-stage validation process:
+    1. Model identifier parsing and validation
+    2. System support verification
+    3. Tenant authorization check
+    4. Database operation error handling
+
+    Args:
+        embd_id (str): Unique identifier for the embedding model in format "model_name@factory"
+        tenant_id (str): Tenant identifier for access control
+
+    Returns:
+        tuple[bool, Response | None]:
+        - First element (bool):
+            - True: Model is available and authorized
+            - False: Validation failed
+        - Second element contains:
+            - None on success
+            - Error detail dict on failure
+
+    Raises:
+        ValueError: When model identifier format is invalid
+        OperationalError: When database connection fails (auto-handled)
+
+    Examples:
+        >>> verify_embedding_availability("text-embedding@openai", "tenant_123")
+        (True, None)
+
+        >>> verify_embedding_availability("invalid_model", "tenant_123")
+        (False, {'code': 101, 'message': "Unsupported model: <invalid_model>"})
+    """
+    try:
+        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(embd_id)
+        if not LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"):
+            return False, get_error_argument_result(f"Unsupported model: <{embd_id}>")
+
+        # Tongyi-Qianwen is added to TenantLLM by default, but remains unusable with empty api_key
+        tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id)
+        is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms)
+
+        is_builtin_model = embd_id in settings.BUILTIN_EMBEDDING_MODELS
+        if not (is_builtin_model or is_tenant_model):
+            return False, get_error_argument_result(f"Unauthorized model: <{embd_id}>")
+    except OperationalError as e:
+        logging.exception(e)
+        return False, get_error_data_result(message="Database operation failed")
+
+    return True, None
```
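A short usage sketch of the new helper's `(ok, error_response)` contract; the caller and the identifier values are illustrative, and it assumes a ragflow environment where the `api` package and its database services are importable and initialized:

```python
from api.utils.api_utils import verify_embedding_availability


def resolve_embedding_model(embd_id: str, tenant_id: str):
    ok, resp = verify_embedding_availability(embd_id, tenant_id)
    if not ok:
        return resp  # already a ready-to-return error response
    return embd_id  # safe to persist on the new dataset


# e.g. resolve_embedding_model("BAAI/bge-large-zh-v1.5@BAAI", "tenant_123")
```

Returning the error response directly is what keeps the happy path in the `create()` hunk above down to three lines.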
````diff
@@ -14,13 +14,102 @@
 # limitations under the License.
 #
 from enum import auto
-from typing import Annotated, List, Optional
+from typing import Annotated, Any
 
+from flask import Request
 from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator
 from strenum import StrEnum
+from werkzeug.exceptions import BadRequest, UnsupportedMediaType
+
+from api.constants import DATASET_NAME_LIMIT
 
 
+def validate_and_parse_json_request(request: Request, validator: type[BaseModel]) -> tuple[dict[str, Any] | None, str | None]:
+    """Validates and parses JSON requests through a multi-stage validation pipeline.
+
+    Implements a robust four-stage validation process:
+    1. Content-Type verification (must be application/json)
+    2. JSON syntax validation
+    3. Payload structure type checking
+    4. Pydantic model validation with error formatting
+
+    Args:
+        request (Request): Flask request object containing HTTP payload
+
+    Returns:
+        tuple[dict[str, Any] | None, str | None]:
+        - First element:
+            - Validated dictionary on success
+            - None on validation failure
+        - Second element:
+            - None on success
+            - Diagnostic error message on failure
+
+    Raises:
+        UnsupportedMediaType: When Content-Type ≠ application/json
+        BadRequest: For structural JSON syntax errors
+        ValidationError: When payload violates Pydantic schema rules
+
+    Examples:
+        Successful validation:
+        ```python
+        # Input: {"name": "Dataset1", "format": "csv"}
+        # Returns: ({"name": "Dataset1", "format": "csv"}, None)
+        ```
+
+        Invalid Content-Type:
+        ```python
+        # Returns: (None, "Unsupported content type: Expected application/json, got text/xml")
+        ```
+
+        Malformed JSON:
+        ```python
+        # Returns: (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding")
+        ```
+    """
+    try:
+        payload = request.get_json() or {}
+    except UnsupportedMediaType:
+        return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
+    except BadRequest:
+        return None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding"
+
+    if not isinstance(payload, dict):
+        return None, f"Invalid request payload: expected object, got {type(payload).__name__}"
+
+    try:
+        validated_request = validator(**payload)
+    except ValidationError as e:
+        return None, format_validation_error_message(e)
+
+    parsed_payload = validated_request.model_dump(by_alias=True)
+
+    return parsed_payload, None
+
+
 def format_validation_error_message(e: ValidationError) -> str:
+    """Formats validation errors into a standardized string format.
+
+    Processes pydantic ValidationError objects to create human-readable error messages
+    containing field locations, error descriptions, and input values.
+
+    Args:
+        e (ValidationError): The validation error instance containing error details
+
+    Returns:
+        str: Formatted error messages joined by newlines. Each line contains:
+            - Field path (dot-separated)
+            - Error message
+            - Truncated input value (max 128 chars)
+
+    Example:
+        >>> try:
+        ...     UserModel(name=123, email="invalid")
+        ... except ValidationError as e:
+        ...     print(format_validation_error_message(e))
+        Field: <name> - Message: <Input should be a valid string> - Value: <123>
+        Field: <email> - Message: <value is not a valid email address> - Value: <invalid>
+    """
     error_messages = []
 
     for error in e.errors():
````
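Because the helper only depends on a Flask `Request`, it can be exercised without a running server through Flask's test request context. A hedged sketch (the app and payloads are illustrative, and assume the same imports as the diff above):

```python
from flask import Flask, request

from api.utils.validation_utils import CreateDatasetReq, validate_and_parse_json_request

app = Flask(__name__)

# Happy path: a valid JSON object that satisfies the schema.
with app.test_request_context(json={"name": "Dataset1"}):
    parsed, err = validate_and_parse_json_request(request, CreateDatasetReq)
    assert err is None and parsed["name"] == "Dataset1"

# Wrong content type: request.get_json() raises UnsupportedMediaType internally.
with app.test_request_context(data="<xml/>", content_type="text/xml"):
    parsed, err = validate_and_parse_json_request(request, CreateDatasetReq)
    assert parsed is None and "Unsupported content type" in err

# Syntactically valid JSON that is not an object.
with app.test_request_context(data='"a"', content_type="application/json"):
    parsed, err = validate_and_parse_json_request(request, CreateDatasetReq)
    assert parsed is None and "Invalid request payload" in err
```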
```diff
@@ -86,7 +175,7 @@ class RaptorConfig(Base):
 
 class GraphragConfig(Base):
     use_graphrag: bool = Field(default=False)
-    entity_types: List[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])
+    entity_types: list[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])
     method: GraphragMethodEnum = Field(default=GraphragMethodEnum.light)
     community: bool = Field(default=False)
     resolution: bool = Field(default=False)
```
````diff
@@ -97,30 +186,59 @@ class ParserConfig(Base):
     auto_questions: int = Field(default=0, ge=0, le=10)
     chunk_token_num: int = Field(default=128, ge=1, le=2048)
     delimiter: str = Field(default=r"\n", min_length=1)
-    graphrag: Optional[GraphragConfig] = None
+    graphrag: GraphragConfig | None = None
     html4excel: bool = False
     layout_recognize: str = "DeepDOC"
-    raptor: Optional[RaptorConfig] = None
-    tag_kb_ids: List[str] = Field(default_factory=list)
+    raptor: RaptorConfig | None = None
+    tag_kb_ids: list[str] = Field(default_factory=list)
     topn_tags: int = Field(default=1, ge=1, le=10)
-    filename_embd_weight: Optional[float] = Field(default=None, ge=0.0, le=1.0)
-    task_page_size: Optional[int] = Field(default=None, ge=1)
-    pages: Optional[List[List[int]]] = None
+    filename_embd_weight: float | None = Field(default=None, ge=0.0, le=1.0)
+    task_page_size: int | None = Field(default=None, ge=1)
+    pages: list[list[int]] | None = None
 
 
 class CreateDatasetReq(Base):
-    name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=128), Field(...)]
-    avatar: Optional[str] = Field(default=None, max_length=65535)
-    description: Optional[str] = Field(default=None, max_length=65535)
-    embedding_model: Annotated[Optional[str], StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")]
+    name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
+    avatar: str | None = Field(default=None, max_length=65535)
+    description: str | None = Field(default=None, max_length=65535)
+    embedding_model: Annotated[str | None, StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")]
     permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)]
     chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")]
     pagerank: int = Field(default=0, ge=0, le=100)
-    parser_config: Optional[ParserConfig] = Field(default=None)
+    parser_config: ParserConfig | None = Field(default=None)
 
     @field_validator("avatar")
     @classmethod
     def validate_avatar_base64(cls, v: str) -> str:
+        """Validates Base64-encoded avatar string format and MIME type compliance.
+
+        Implements a three-stage validation workflow:
+        1. MIME prefix existence check
+        2. MIME type format validation
+        3. Supported type verification
+
+        Args:
+            v (str): Raw avatar field value
+
+        Returns:
+            str: Validated Base64 string
+
+        Raises:
+            ValueError: For structural errors in these cases:
+                - Missing MIME prefix header
+                - Invalid MIME prefix format
+                - Unsupported image MIME type
+
+        Example:
+            ```python
+            # Valid case
+            CreateDatasetReq(avatar="data:image/png;base64,iVBORw0KGg...")
+
+            # Invalid cases
+            CreateDatasetReq(avatar="image/jpeg;base64,...")  # Missing 'data:' prefix
+            CreateDatasetReq(avatar="data:video/mp4;base64,...")  # Unsupported MIME type
+            ```
+        """
         if v is None:
             return v
 
@@ -141,22 +259,83 @@ class CreateDatasetReq(Base):
     @field_validator("embedding_model", mode="after")
     @classmethod
     def validate_embedding_model(cls, v: str) -> str:
+        """Validates embedding model identifier format compliance.
+
+        Validation pipeline:
+        1. Structural format verification
+        2. Component non-empty check
+        3. Value normalization
+
+        Args:
+            v (str): Raw model identifier
+
+        Returns:
+            str: Validated <model_name>@<provider> format
+
+        Raises:
+            ValueError: For these violations:
+                - Missing @ separator
+                - Empty model_name/provider
+                - Invalid component structure
+
+        Examples:
+            Valid: "text-embedding-3-large@openai"
+            Invalid: "invalid_model" (no @)
+            Invalid: "@openai" (empty model_name)
+            Invalid: "text-embedding-3-large@" (empty provider)
+        """
         if "@" not in v:
-            raise ValueError("Embedding model must be xxx@yyy")
+            raise ValueError("Embedding model identifier must follow <model_name>@<provider> format")
+
+        components = v.split("@", 1)
+        if len(components) != 2 or not all(components):
+            raise ValueError("Both model_name and provider must be non-empty strings")
+
+        model_name, provider = components
+        if not model_name.strip() or not provider.strip():
+            raise ValueError("Model name and provider cannot be whitespace-only strings")
         return v
 
     @field_validator("permission", mode="before")
     @classmethod
     def permission_auto_lowercase(cls, v: str) -> str:
-        if isinstance(v, str):
-            return v.lower()
-        return v
+        """Normalize permission input to lowercase for consistent PermissionEnum matching.
+
+        Args:
+            v (str): Raw input value for the permission field
+
+        Returns:
+            Lowercase string if input is string type, otherwise returns original value
+
+        Behavior:
+            - Converts string inputs to lowercase (e.g., "ME" → "me")
+            - Non-string values pass through unchanged
+            - Works in validation pre-processing stage (before enum conversion)
+        """
+        return v.lower() if isinstance(v, str) else v
 
     @field_validator("parser_config", mode="after")
     @classmethod
-    def validate_parser_config_json_length(cls, v: Optional[ParserConfig]) -> Optional[ParserConfig]:
-        if v is not None:
-            json_str = v.model_dump_json()
-            if len(json_str) > 65535:
-                raise ValueError("Parser config have at most 65535 characters")
+    def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserConfig | None:
+        """Validates serialized JSON length constraints for parser configuration.
+
+        Implements a three-stage validation workflow:
+        1. Null check - bypass validation for empty configurations
+        2. Model serialization - convert Pydantic model to JSON string
+        3. Size verification - enforce maximum allowed payload size
+
+        Args:
+            v (ParserConfig | None): Raw parser configuration object
+
+        Returns:
+            ParserConfig | None: Validated configuration object
+
+        Raises:
+            ValueError: When serialized JSON exceeds 65,535 characters
+        """
+        if v is None:
+            return v
+
+        if (json_str := v.model_dump_json()) and len(json_str) > 65535:
+            raise ValueError(f"Parser config exceeds size limit (max 65,535 characters). Current size: {len(json_str):,}")
         return v
````
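The validators above can also be exercised directly against the Pydantic model, independent of any HTTP layer; a sketch with illustrative field values:

```python
from pydantic import ValidationError

from api.utils.validation_utils import CreateDatasetReq, format_validation_error_message

# Normalization: name whitespace is stripped, permission is lowercased
# before enum matching.
ok = CreateDatasetReq(name="  demo  ", permission="ME", embedding_model="BAAI/bge-large-zh-v1.5@BAAI")
assert ok.name == "demo"

# Format violation: empty model_name before the "@" separator.
try:
    CreateDatasetReq(name="demo", embedding_model="@BAAI")
except ValidationError as e:
    # Expected to print something like:
    # Field: <embedding_model> - Message: <...non-empty strings> - Value: <@BAAI>
    print(format_validation_error_message(e))
```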
```diff
@@ -39,23 +39,23 @@ SESSION_WITH_CHAT_NAME_LIMIT = 255
 
 
 # DATASET MANAGEMENT
-def create_dataset(auth, payload=None):
-    res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload)
+def create_dataset(auth, payload=None, headers=HEADERS, data=None):
+    res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
     return res.json()
 
 
-def list_datasets(auth, params=None):
-    res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, params=params)
+def list_datasets(auth, params=None, headers=HEADERS):
+    res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params)
     return res.json()
 
 
-def update_dataset(auth, dataset_id, payload=None):
-    res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=HEADERS, auth=auth, json=payload)
+def update_dataset(auth, dataset_id, payload=None, headers=HEADERS):
+    res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload)
     return res.json()
 
 
-def delete_datasets(auth, payload=None):
-    res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload)
+def delete_datasets(auth, payload=None, headers=HEADERS):
+    res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload)
     return res.json()
 
```
```diff
@@ -98,6 +98,25 @@ class TestDatasetCreation:
         assert res["code"] == 101, res
         assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
 
+    def test_bad_content_type(self, get_http_api_auth):
+        BAD_CONTENT_TYPE = "text/xml"
+        res = create_dataset(get_http_api_auth, {"name": "name"}, {"Content-Type": BAD_CONTENT_TYPE})
+        assert res["code"] == 101, res
+        assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res
+
+    @pytest.mark.parametrize(
+        "payload, expected_message",
+        [
+            ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"),
+            ('"a"', "Invalid request payload: expected objec"),
+        ],
+        ids=["malformed_json_syntax", "invalid_request_payload_type"],
+    )
+    def test_bad_payload(self, get_http_api_auth, payload, expected_message):
+        res = create_dataset(get_http_api_auth, data=payload)
+        assert res["code"] == 101, res
+        assert expected_message in res["message"], res
+
     def test_avatar(self, get_http_api_auth, tmp_path):
         fn = create_image_file(tmp_path / "ragflow_test.png")
         payload = {
@@ -158,7 +177,7 @@ class TestDatasetCreation:
             ("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"),
             ("embedding_model_default", None),
         ],
-        ids=["builtin_baai", "builtin_youdao", "tenant__zhipu", "default"],
+        ids=["builtin_baai", "builtin_youdao", "tenant_zhipu", "default"],
     )
     def test_valid_embedding_model(self, get_http_api_auth, name, embedding_model):
         if embedding_model is None:
@@ -178,29 +197,39 @@ class TestDatasetCreation:
         [
             ("unknown_llm_name", "unknown@ZHIPU-AI"),
             ("unknown_llm_factory", "embedding-3@unknown"),
-            ("tenant_no_auth", "deepseek-chat@DeepSeek"),
+            ("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"),
+            ("tenant_no_auth", "text-embedding-3-small@OpenAI"),
         ],
-        ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth"],
+        ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"],
     )
     def test_invalid_embedding_model(self, get_http_api_auth, name, embedding_model):
         payload = {"name": name, "embedding_model": embedding_model}
         res = create_dataset(get_http_api_auth, payload)
         assert res["code"] == 101, res
-        assert res["message"] == f"The embedding_model '{embedding_model}' is not supported", res
+        if "tenant_no_auth" in name:
+            assert res["message"] == f"Unauthorized model: <{embedding_model}>", res
+        else:
+            assert res["message"] == f"Unsupported model: <{embedding_model}>", res
 
     @pytest.mark.parametrize(
         "name, embedding_model",
         [
-            ("builtin_missing_at", "BAAI/bge-large-zh-v1.5"),
-            ("tenant_missing_at", "embedding-3ZHIPU-AI"),
+            ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
+            ("missing_model_name", "@BAAI"),
+            ("missing_provider", "BAAI/bge-large-zh-v1.5@"),
+            ("whitespace_only_model_name", " @BAAI"),
+            ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
         ],
-        ids=["builtin_missing_at", "tenant_missing_at"],
+        ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
     )
-    def test_embedding_model_missing_at(self, get_http_api_auth, name, embedding_model):
+    def test_embedding_model_format(self, get_http_api_auth, name, embedding_model):
         payload = {"name": name, "embedding_model": embedding_model}
         res = create_dataset(get_http_api_auth, payload)
         assert res["code"] == 101, res
-        assert "Embedding model must be xxx@yyy" in res["message"], res
+        if name == "missing_at":
+            assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
+        else:
+            assert "Both model_name and provider must be non-empty strings" in res["message"], res
 
     @pytest.mark.parametrize(
         "name, permission",
@@ -485,7 +514,7 @@ class TestDatasetCreation:
             ("raptor_random_seed_min_limit", {"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"),
             ("raptor_random_seed_float_not_allowed", {"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
             ("raptor_random_seed_type_invalid", {"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
-            ("parser_config_type_invalid", {"delimiter": "a" * 65536}, "Parser config have at most 65535 characters"),
+            ("parser_config_type_invalid", {"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"),
         ],
         ids=[
             "auto_keywords_min_limit",
```