From cac0d3c33e04790cc8874162f7d97827ef4783ba Mon Sep 17 00:00:00 2001 From: Arcaner <52057416+lrhan321@users.noreply.github.com> Date: Wed, 16 Apr 2025 19:21:50 +0800 Subject: [PATCH] fix: implement robust file type checks to align with existing logic (#17557) Co-authored-by: Bowen Liang --- api/core/app/apps/base_app_generator.py | 2 + api/core/app/apps/workflow/app_generator.py | 6 +- api/factories/file_factory.py | 43 +++- .../factories/test_build_from_mapping.py | 198 ++++++++++++++++++ 4 files changed, 243 insertions(+), 6 deletions(-) create mode 100644 api/tests/unit_tests/factories/test_build_from_mapping.py diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 5d559b96d7..a83b75cc1a 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -17,6 +17,7 @@ class BaseAppGenerator: user_inputs: Optional[Mapping[str, Any]], variables: Sequence["VariableEntity"], tenant_id: str, + strict_type_validation: bool = False, ) -> Mapping[str, Any]: user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values @@ -37,6 +38,7 @@ class BaseAppGenerator: allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, ), + strict_type_validation=strict_type_validation, ) for k, v in user_inputs.items() if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index cc7bcdeee1..08986b16f0 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -92,6 +92,7 @@ class WorkflowAppGenerator(BaseAppGenerator): mappings=files, tenant_id=app_model.tenant_id, config=file_extra_config, + strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, ) # convert to app config @@ -114,7 +115,10 @@ class WorkflowAppGenerator(BaseAppGenerator): app_config=app_config, file_upload_config=file_extra_config, inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + user_inputs=inputs, + variables=app_config.variables, + tenant_id=app_model.tenant_id, + strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, ), files=list(system_files), user_id=user.id, diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index b69621ba5b..52f119936f 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -52,6 +52,7 @@ def build_from_mapping( mapping: Mapping[str, Any], tenant_id: str, config: FileUploadConfig | None = None, + strict_type_validation: bool = False, ) -> File: transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method")) @@ -69,6 +70,7 @@ def build_from_mapping( mapping=mapping, tenant_id=tenant_id, transfer_method=transfer_method, + strict_type_validation=strict_type_validation, ) if config and not _is_file_valid_with_config( @@ -87,12 +89,14 @@ def build_from_mappings( mappings: Sequence[Mapping[str, Any]], config: FileUploadConfig | None = None, tenant_id: str, + strict_type_validation: bool = False, ) -> Sequence[File]: files = [ build_from_mapping( mapping=mapping, tenant_id=tenant_id, config=config, + strict_type_validation=strict_type_validation, ) for mapping in mappings ] @@ -116,6 +120,7 @@ def _build_from_local_file( mapping: Mapping[str, Any], tenant_id: str, transfer_method: FileTransferMethod, + strict_type_validation: bool = False, ) -> File: upload_file_id = mapping.get("upload_file_id") if not upload_file_id: @@ -134,10 +139,16 @@ def _build_from_local_file( if row is None: raise ValueError("Invalid upload file") - file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) - if file_type.value != mapping.get("type", "custom"): + detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + specified_type = mapping.get("type", "custom") + + if strict_type_validation and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") + file_type = ( + FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type + ) + return File( id=mapping.get("id"), filename=row.name, @@ -158,6 +169,7 @@ def _build_from_remote_url( mapping: Mapping[str, Any], tenant_id: str, transfer_method: FileTransferMethod, + strict_type_validation: bool = False, ) -> File: upload_file_id = mapping.get("upload_file_id") if upload_file_id: @@ -174,10 +186,21 @@ def _build_from_remote_url( if upload_file is None: raise ValueError("Invalid upload file") - file_type = _standardize_file_type(extension="." + upload_file.extension, mime_type=upload_file.mime_type) - if file_type.value != mapping.get("type", "custom"): + detected_file_type = _standardize_file_type( + extension="." + upload_file.extension, mime_type=upload_file.mime_type + ) + + specified_type = mapping.get("type") + + if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") + file_type = ( + FileType(specified_type) + if specified_type and specified_type != FileType.CUSTOM.value + else detected_file_type + ) + return File( id=mapping.get("id"), filename=upload_file.name, @@ -237,6 +260,7 @@ def _build_from_tool_file( mapping: Mapping[str, Any], tenant_id: str, transfer_method: FileTransferMethod, + strict_type_validation: bool = False, ) -> File: tool_file = ( db.session.query(ToolFile) @@ -252,7 +276,16 @@ def _build_from_tool_file( extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype) + + specified_type = mapping.get("type") + + if strict_type_validation and specified_type and detected_file_type.value != specified_type: + raise ValueError("Detected file type does not match the specified type. Please verify the file.") + + file_type = ( + FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type + ) return File( id=mapping.get("id"), diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py new file mode 100644 index 0000000000..48463a369e --- /dev/null +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -0,0 +1,198 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from httpx import Response + +from factories.file_factory import ( + File, + FileTransferMethod, + FileType, + FileUploadConfig, + build_from_mapping, +) +from models import ToolFile, UploadFile + +# Test Data +TEST_TENANT_ID = "test_tenant_id" +TEST_UPLOAD_FILE_ID = str(uuid.uuid4()) +TEST_TOOL_FILE_ID = str(uuid.uuid4()) +TEST_REMOTE_URL = "http://example.com/test.jpg" + +# Test Config +TEST_CONFIG = FileUploadConfig( + allowed_file_types=["image", "document"], + allowed_file_extensions=[".jpg", ".pdf"], + allowed_file_upload_methods=[FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE], + number_limits=10, +) + + +# Fixtures +@pytest.fixture +def mock_upload_file(): + mock = MagicMock(spec=UploadFile) + mock.id = TEST_UPLOAD_FILE_ID + mock.tenant_id = TEST_TENANT_ID + mock.name = "test.jpg" + mock.extension = "jpg" + mock.mime_type = "image/jpeg" + mock.source_url = TEST_REMOTE_URL + mock.size = 1024 + mock.key = "test_key" + with patch("factories.file_factory.db.session.scalar", return_value=mock) as m: + yield m + + +@pytest.fixture +def mock_tool_file(): + mock = MagicMock(spec=ToolFile) + mock.id = TEST_TOOL_FILE_ID + mock.tenant_id = TEST_TENANT_ID + mock.name = "tool_file.pdf" + mock.file_key = "tool_file.pdf" + mock.mimetype = "application/pdf" + mock.original_url = "http://example.com/tool.pdf" + mock.size = 2048 + with patch("factories.file_factory.db.session.query") as mock_query: + mock_query.return_value.filter.return_value.first.return_value = mock + yield mock + + +@pytest.fixture +def mock_http_head(): + def _mock_response(filename, size, content_type): + return Response( + status_code=200, + headers={ + "Content-Disposition": f'attachment; filename="{filename}"', + "Content-Length": str(size), + "Content-Type": content_type, + }, + ) + + with patch("factories.file_factory.ssrf_proxy.head") as mock_head: + mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg") + yield mock_head + + +# Helper functions +def local_file_mapping(file_type="image"): + return { + "transfer_method": "local_file", + "upload_file_id": TEST_UPLOAD_FILE_ID, + "type": file_type, + } + + +def tool_file_mapping(file_type="document"): + return { + "transfer_method": "tool_file", + "tool_file_id": TEST_TOOL_FILE_ID, + "type": file_type, + } + + +# Tests +def test_build_from_mapping_backward_compatibility(mock_upload_file): + mapping = local_file_mapping(file_type="image") + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + assert isinstance(file, File) + assert file.transfer_method == FileTransferMethod.LOCAL_FILE + assert file.type == FileType.IMAGE + assert file.related_id == TEST_UPLOAD_FILE_ID + + +@pytest.mark.parametrize( + ("file_type", "should_pass", "expected_error"), + [ + ("image", True, None), + ("document", False, "Detected file type does not match"), + ], +) +def test_build_from_local_file_strict_validation(mock_upload_file, file_type, should_pass, expected_error): + mapping = local_file_mapping(file_type=file_type) + if should_pass: + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True) + assert file.type == FileType(file_type) + else: + with pytest.raises(ValueError, match=expected_error): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True) + + +@pytest.mark.parametrize( + ("file_type", "should_pass", "expected_error"), + [ + ("document", True, None), + ("image", False, "Detected file type does not match"), + ], +) +def test_build_from_tool_file_strict_validation(mock_tool_file, file_type, should_pass, expected_error): + """Strict type validation for tool_file.""" + mapping = tool_file_mapping(file_type=file_type) + if should_pass: + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True) + assert file.type == FileType(file_type) + else: + with pytest.raises(ValueError, match=expected_error): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True) + + +def test_build_from_remote_url(mock_http_head): + mapping = { + "transfer_method": "remote_url", + "url": TEST_REMOTE_URL, + "type": "image", + } + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + assert file.transfer_method == FileTransferMethod.REMOTE_URL + assert file.type == FileType.IMAGE + assert file.filename == "remote_test.jpg" + assert file.size == 2048 + + +def test_tool_file_not_found(): + """Test ToolFile not found in database.""" + with patch("factories.file_factory.db.session.query") as mock_query: + mock_query.return_value.filter.return_value.first.return_value = None + mapping = tool_file_mapping() + with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + + +def test_local_file_not_found(): + """Test UploadFile not found in database.""" + with patch("factories.file_factory.db.session.scalar", return_value=None): + mapping = local_file_mapping() + with pytest.raises(ValueError, match="Invalid upload file"): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + + +def test_build_without_type_specification(mock_upload_file): + """Test the situation where no file type is specified""" + mapping = { + "transfer_method": "local_file", + "upload_file_id": TEST_UPLOAD_FILE_ID, + # leave out the type + } + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + # It should automatically infer the type as "image" based on the file extension + assert file.type == FileType.IMAGE + + +@pytest.mark.parametrize( + ("file_type", "should_pass", "expected_error"), + [ + ("image", True, None), + ("video", False, "File validation failed"), + ], +) +def test_file_validation_with_config(mock_upload_file, file_type, should_pass, expected_error): + """Test the validation of files and configurations""" + mapping = local_file_mapping(file_type=file_type) + if should_pass: + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG) + assert file is not None + else: + with pytest.raises(ValueError, match=expected_error): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)