feat(file_factory): Standardize custom file type into known types (#11028)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-11-24 15:29:43 +08:00 committed by GitHub
parent 03ba4bc760
commit 8565c18e84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,10 +1,11 @@
import mimetypes import mimetypes
from collections.abc import Callable, Mapping, Sequence from collections.abc import Callable, Mapping, Sequence
from typing import Any from typing import Any, cast
import httpx import httpx
from sqlalchemy import select from sqlalchemy import select
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from extensions.ext_database import db from extensions.ext_database import db
@ -71,7 +72,12 @@ def build_from_mapping(
transfer_method=transfer_method, transfer_method=transfer_method,
) )
if not _is_file_valid_with_config(file=file, config=config): if not _is_file_valid_with_config(
input_file_type=mapping.get("type", FileType.CUSTOM),
file_extension=file.extension,
file_transfer_method=file.transfer_method,
config=config,
):
raise ValueError(f"File validation failed for file: {file.filename}") raise ValueError(f"File validation failed for file: {file.filename}")
return file return file
@ -114,17 +120,18 @@ def _build_from_local_file(
tenant_id: str, tenant_id: str,
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
) -> File: ) -> File:
file_type = FileType.value_of(mapping.get("type"))
stmt = select(UploadFile).where( stmt = select(UploadFile).where(
UploadFile.id == mapping.get("upload_file_id"), UploadFile.id == mapping.get("upload_file_id"),
UploadFile.tenant_id == tenant_id, UploadFile.tenant_id == tenant_id,
) )
row = db.session.scalar(stmt) row = db.session.scalar(stmt)
if row is None: if row is None:
raise ValueError("Invalid upload file") raise ValueError("Invalid upload file")
file_type = FileType(mapping.get("type"))
file_type = _standardize_file_type(file_type, extension="." + row.extension, mime_type=row.mime_type)
return File( return File(
id=mapping.get("id"), id=mapping.get("id"),
filename=row.name, filename=row.name,
@ -152,11 +159,14 @@ def _build_from_remote_url(
mime_type, filename, file_size = _get_remote_file_info(url) mime_type, filename, file_size = _get_remote_file_info(url)
extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin" extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"
file_type = FileType(mapping.get("type"))
file_type = _standardize_file_type(file_type, extension=extension, mime_type=mime_type)
return File( return File(
id=mapping.get("id"), id=mapping.get("id"),
filename=filename, filename=filename,
tenant_id=tenant_id, tenant_id=tenant_id,
type=FileType.value_of(mapping.get("type")), type=file_type,
transfer_method=transfer_method, transfer_method=transfer_method,
remote_url=url, remote_url=url,
mime_type=mime_type, mime_type=mime_type,
@ -171,6 +181,7 @@ def _get_remote_file_info(url: str):
mime_type = mimetypes.guess_type(filename)[0] or "" mime_type = mimetypes.guess_type(filename)[0] or ""
resp = ssrf_proxy.head(url, follow_redirects=True) resp = ssrf_proxy.head(url, follow_redirects=True)
resp = cast(httpx.Response, resp)
if resp.status_code == httpx.codes.OK: if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"): if content_disposition := resp.headers.get("Content-Disposition"):
filename = str(content_disposition.split("filename=")[-1].strip('"')) filename = str(content_disposition.split("filename=")[-1].strip('"'))
@ -180,20 +191,6 @@ def _get_remote_file_info(url: str):
return mime_type, filename, file_size return mime_type, filename, file_size
def _get_file_type_by_mimetype(mime_type: str) -> FileType:
    """Map a MIME type string to a ``FileType`` by substring matching.

    The markers are tested in order (image, video, audio, then text/pdf) and
    the first one contained in ``mime_type`` decides the result. When no
    marker is contained, ``FileType.CUSTOM`` is returned.
    """
    # Order matters: it mirrors the priority of the checks.
    ordered_markers = (
        ("image", FileType.IMAGE),
        ("video", FileType.VIDEO),
        ("audio", FileType.AUDIO),
        ("text", FileType.DOCUMENT),
        ("pdf", FileType.DOCUMENT),
    )
    for marker, matched_type in ordered_markers:
        if marker in mime_type:
            return matched_type
    return FileType.CUSTOM
def _build_from_tool_file( def _build_from_tool_file(
*, *,
mapping: Mapping[str, Any], mapping: Mapping[str, Any],
@ -213,7 +210,8 @@ def _build_from_tool_file(
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype)) file_type = FileType(mapping.get("type"))
file_type = _standardize_file_type(file_type, extension=extension, mime_type=tool_file.mimetype)
return File( return File(
id=mapping.get("id"), id=mapping.get("id"),
@ -229,18 +227,68 @@ def _build_from_tool_file(
) )
def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool: def _is_file_valid_with_config(
if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM: *,
input_file_type: str,
file_extension: str,
file_transfer_method: FileTransferMethod,
config: FileUploadConfig,
) -> bool:
if (
config.allowed_file_types
and input_file_type not in config.allowed_file_types
and input_file_type != FileType.CUSTOM
):
return False return False
if config.allowed_file_extensions and file.extension not in config.allowed_file_extensions: if config.allowed_file_extensions and file_extension not in config.allowed_file_extensions:
return False return False
if config.allowed_file_upload_methods and file.transfer_method not in config.allowed_file_upload_methods: if config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods:
return False return False
if file.type == FileType.IMAGE and config.image_config: if input_file_type == FileType.IMAGE and config.image_config:
if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods: if config.image_config.transfer_methods and file_transfer_method not in config.image_config.transfer_methods:
return False return False
return True return True
def _standardize_file_type(file_type: FileType, /, *, extension: str = "", mime_type: str = "") -> FileType:
    """Resolve a custom file type into a known type where possible.

    Any type other than ``FileType.CUSTOM`` is returned as-is (coerced through
    ``FileType``). For a custom type, the extension is consulted first; the
    MIME type is only consulted when the extension yields nothing. If neither
    produces a type, ``FileType.CUSTOM`` is returned.
    """
    if file_type == FileType.CUSTOM:
        # An empty extension is skipped entirely rather than looked up.
        inferred = _get_file_type_by_extension(extension) if extension else None
        if inferred is None and mime_type:
            inferred = _get_file_type_by_mimetype(mime_type)
        return inferred if inferred else FileType.CUSTOM
    return FileType(file_type)
def _get_file_type_by_extension(extension: str) -> FileType | None:
    """Look up the ``FileType`` for a file extension.

    Leading dots are stripped before the lookup. The extension collections
    are checked in the order image, video, audio, document; ``None`` is
    returned when the extension is in none of them.
    """
    bare_extension = extension.lstrip(".")
    # Order matters: the first collection containing the extension wins.
    ordered_lookup = (
        (IMAGE_EXTENSIONS, FileType.IMAGE),
        (VIDEO_EXTENSIONS, FileType.VIDEO),
        (AUDIO_EXTENSIONS, FileType.AUDIO),
        (DOCUMENT_EXTENSIONS, FileType.DOCUMENT),
    )
    for known_extensions, matched_type in ordered_lookup:
        if bare_extension in known_extensions:
            return matched_type
    return None
def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:
    """Map a MIME type string to a ``FileType`` by substring matching.

    Matching is case-insensitive because media type and subtype names are
    case-insensitive (RFC 2045, section 5.1); a value such as ``Image/PNG``
    is therefore recognised as an image rather than falling through.

    The markers are tested in order: image, video, audio, then text/pdf
    (both map to ``FileType.DOCUMENT``). When nothing matches,
    ``FileType.CUSTOM`` is returned.

    NOTE(review): the annotation allows ``None`` but this function never
    returns it; the annotation is kept unchanged so the signature stays the
    same for callers.
    """
    normalized = mime_type.lower()
    if "image" in normalized:
        return FileType.IMAGE
    if "video" in normalized:
        return FileType.VIDEO
    if "audio" in normalized:
        return FileType.AUDIO
    if "text" in normalized or "pdf" in normalized:
        return FileType.DOCUMENT
    return FileType.CUSTOM