From e5397c5ec2e3ec3d867cc7194f9a771266afebaf Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 31 Oct 2024 15:16:34 +0800 Subject: [PATCH] feat(app_dsl_service): enhance error handling and DSL version management (#10108) --- api/models/model.py | 2 +- api/services/app_dsl_service/__init__.py | 3 + api/services/app_dsl_service/exc.py | 34 ++++ .../service.py} | 161 +++++++++++------- .../app_dsl_service/test_app_dsl_service.py | 41 +++++ 5 files changed, 178 insertions(+), 63 deletions(-) create mode 100644 api/services/app_dsl_service/__init__.py create mode 100644 api/services/app_dsl_service/exc.py rename api/services/{app_dsl_service.py => app_dsl_service/service.py} (75%) create mode 100644 api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py diff --git a/api/models/model.py b/api/models/model.py index 3bd5886d75..20fbee29aa 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -396,7 +396,7 @@ class AppModelConfig(db.Model): "file_upload": self.file_upload_dict, } - def from_model_config_dict(self, model_config: dict): + def from_model_config_dict(self, model_config: Mapping[str, Any]): self.opening_statement = model_config.get("opening_statement") self.suggested_questions = ( json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None diff --git a/api/services/app_dsl_service/__init__.py b/api/services/app_dsl_service/__init__.py new file mode 100644 index 0000000000..9fc988ffb3 --- /dev/null +++ b/api/services/app_dsl_service/__init__.py @@ -0,0 +1,3 @@ +from .service import AppDslService + +__all__ = ["AppDslService"] diff --git a/api/services/app_dsl_service/exc.py b/api/services/app_dsl_service/exc.py new file mode 100644 index 0000000000..6da4b1938f --- /dev/null +++ b/api/services/app_dsl_service/exc.py @@ -0,0 +1,34 @@ +class DSLVersionNotSupportedError(ValueError): + """Raised when the imported DSL version is not supported by the current Dify version.""" + + +class InvalidYAMLFormatError(ValueError): + """Raised when the provided YAML format is invalid.""" + + +class MissingAppDataError(ValueError): + """Raised when the app data is missing in the provided DSL.""" + + +class InvalidAppModeError(ValueError): + """Raised when the app mode is invalid.""" + + +class MissingWorkflowDataError(ValueError): + """Raised when the workflow data is missing in the provided DSL.""" + + +class MissingModelConfigError(ValueError): + """Raised when the model config data is missing in the provided DSL.""" + + +class FileSizeLimitExceededError(ValueError): + """Raised when the file size exceeds the allowed limit.""" + + +class EmptyContentError(ValueError): + """Raised when the content fetched from the URL is empty.""" + + +class ContentDecodingError(ValueError): + """Raised when there is an error decoding the content.""" diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service/service.py similarity index 75% rename from api/services/app_dsl_service.py rename to api/services/app_dsl_service/service.py index 750d0a8cd2..2ff774db5f 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service/service.py @@ -1,8 +1,11 @@ import logging +from collections.abc import Mapping +from typing import Any -import httpx -import yaml # type: ignore +import yaml +from packaging import version +from core.helper import ssrf_proxy from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_database import db from factories import variable_factory @@ -11,6 +14,18 @@ from models.model import App, AppMode, AppModelConfig from models.workflow import Workflow from services.workflow_service import WorkflowService +from .exc import ( + ContentDecodingError, + DSLVersionNotSupportedError, + EmptyContentError, + FileSizeLimitExceededError, + InvalidAppModeError, + InvalidYAMLFormatError, + MissingAppDataError, + MissingModelConfigError, + MissingWorkflowDataError, +) + logger = logging.getLogger(__name__) current_dsl_version = "0.1.2" @@ -30,32 +45,21 @@ class AppDslService: :param args: request args :param account: Account instance """ - try: - max_size = 10 * 1024 * 1024 # 10MB - timeout = httpx.Timeout(10.0) - with httpx.stream("GET", url.strip(), follow_redirects=True, timeout=timeout) as response: - response.raise_for_status() - total_size = 0 - content = b"" - for chunk in response.iter_bytes(): - total_size += len(chunk) - if total_size > max_size: - raise ValueError("File size exceeds the limit of 10MB") - content += chunk - except httpx.HTTPStatusError as http_err: - raise ValueError(f"HTTP error occurred: {http_err}") - except httpx.RequestError as req_err: - raise ValueError(f"Request error occurred: {req_err}") - except Exception as e: - raise ValueError(f"Failed to fetch DSL from URL: {e}") + max_size = 10 * 1024 * 1024 # 10MB + response = ssrf_proxy.get(url.strip(), follow_redirects=True, timeout=(10, 10)) + response.raise_for_status() + content = response.content + + if len(content) > max_size: + raise FileSizeLimitExceededError("File size exceeds the limit of 10MB") if not content: - raise ValueError("Empty content from url") + raise EmptyContentError("Empty content from url") try: data = content.decode("utf-8") except UnicodeDecodeError as e: - raise ValueError(f"Error decoding content: {e}") + raise ContentDecodingError(f"Error decoding content: {e}") return cls.import_and_create_new_app(tenant_id, data, args, account) @@ -71,14 +75,14 @@ class AppDslService: try: import_data = yaml.safe_load(data) except yaml.YAMLError: - raise ValueError("Invalid YAML format in data argument.") + raise InvalidYAMLFormatError("Invalid YAML format in data argument.") # check or repair dsl version - import_data = cls._check_or_fix_dsl(import_data) + import_data = _check_or_fix_dsl(import_data) app_data = import_data.get("app") if not app_data: - raise ValueError("Missing app in data argument") + raise MissingAppDataError("Missing app in data argument") # get app basic info name = args.get("name") or app_data.get("name") @@ -90,11 +94,18 @@ class AppDslService: # import dsl and create app app_mode = AppMode.value_of(app_data.get("mode")) + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + workflow_data = import_data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) + app = cls._import_and_create_new_workflow_based_app( tenant_id=tenant_id, app_mode=app_mode, - workflow_data=import_data.get("workflow"), + workflow_data=workflow_data, account=account, name=name, description=description, @@ -104,10 +115,16 @@ class AppDslService: use_icon_as_answer_icon=use_icon_as_answer_icon, ) elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: + model_config = import_data.get("model_config") + if not model_config or not isinstance(model_config, dict): + raise MissingModelConfigError( + "Missing model_config in data argument when app mode is chat, agent-chat or completion" + ) + app = cls._import_and_create_new_model_config_based_app( tenant_id=tenant_id, app_mode=app_mode, - model_config_data=import_data.get("model_config"), + model_config_data=model_config, account=account, name=name, description=description, @@ -117,7 +134,7 @@ class AppDslService: use_icon_as_answer_icon=use_icon_as_answer_icon, ) else: - raise ValueError("Invalid app mode") + raise InvalidAppModeError("Invalid app mode") return app @@ -132,26 +149,32 @@ class AppDslService: try: import_data = yaml.safe_load(data) except yaml.YAMLError: - raise ValueError("Invalid YAML format in data argument.") + raise InvalidYAMLFormatError("Invalid YAML format in data argument.") # check or repair dsl version - import_data = cls._check_or_fix_dsl(import_data) + import_data = _check_or_fix_dsl(import_data) app_data = import_data.get("app") if not app_data: - raise ValueError("Missing app in data argument") + raise MissingAppDataError("Missing app in data argument") # import dsl and overwrite app app_mode = AppMode.value_of(app_data.get("mode")) if app_mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - raise ValueError("Only support import workflow in advanced-chat or workflow app.") + raise InvalidAppModeError("Only support import workflow in advanced-chat or workflow app.") if app_data.get("mode") != app_model.mode: raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") + workflow_data = import_data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) + return cls._import_and_overwrite_workflow_based_app( app_model=app_model, - workflow_data=import_data.get("workflow"), + workflow_data=workflow_data, account=account, ) @@ -186,35 +209,12 @@ class AppDslService: return yaml.dump(export_data, allow_unicode=True) - @classmethod - def _check_or_fix_dsl(cls, import_data: dict) -> dict: - """ - Check or fix dsl - - :param import_data: import data - """ - if not import_data.get("version"): - import_data["version"] = "0.1.0" - - if not import_data.get("kind") or import_data.get("kind") != "app": - import_data["kind"] = "app" - - if import_data.get("version") != current_dsl_version: - # Currently only one DSL version, so no difference checks or compatibility fixes will be performed. - logger.warning( - f"DSL version {import_data.get('version')} is not compatible " - f"with current version {current_dsl_version}, related to " - f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}." - ) - - return import_data - @classmethod def _import_and_create_new_workflow_based_app( cls, tenant_id: str, app_mode: AppMode, - workflow_data: dict, + workflow_data: Mapping[str, Any], account: Account, name: str, description: str, @@ -238,7 +238,9 @@ class AppDslService: :param use_icon_as_answer_icon: use app icon as answer icon """ if not workflow_data: - raise ValueError("Missing workflow in data argument when app mode is advanced-chat or workflow") + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) app = cls._create_app( tenant_id=tenant_id, @@ -277,7 +279,7 @@ class AppDslService: @classmethod def _import_and_overwrite_workflow_based_app( - cls, app_model: App, workflow_data: dict, account: Account + cls, app_model: App, workflow_data: Mapping[str, Any], account: Account ) -> Workflow: """ Import app dsl and overwrite workflow based app @@ -287,7 +289,9 @@ class AppDslService: :param account: Account instance """ if not workflow_data: - raise ValueError("Missing workflow in data argument when app mode is advanced-chat or workflow") + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) # fetch draft workflow by app_model workflow_service = WorkflowService() @@ -323,7 +327,7 @@ class AppDslService: cls, tenant_id: str, app_mode: AppMode, - model_config_data: dict, + model_config_data: Mapping[str, Any], account: Account, name: str, description: str, @@ -345,7 +349,9 @@ class AppDslService: :param icon_background: app icon background """ if not model_config_data: - raise ValueError("Missing model_config in data argument when app mode is chat, agent-chat or completion") + raise MissingModelConfigError( + "Missing model_config in data argument when app mode is chat, agent-chat or completion" + ) app = cls._create_app( tenant_id=tenant_id, @@ -448,3 +454,34 @@ class AppDslService: raise ValueError("Missing app configuration, please check.") export_data["model_config"] = app_model_config.to_dict() + + +def _check_or_fix_dsl(import_data: dict[str, Any]) -> Mapping[str, Any]: + """ + Check or fix dsl + + :param import_data: import data + :raises DSLVersionNotSupportedError: if the imported DSL version is newer than the current version + """ + if not import_data.get("version"): + import_data["version"] = "0.1.0" + + if not import_data.get("kind") or import_data.get("kind") != "app": + import_data["kind"] = "app" + + imported_version = import_data.get("version") + if imported_version != current_dsl_version: + if imported_version and version.parse(imported_version) > version.parse(current_dsl_version): + raise DSLVersionNotSupportedError( + f"The imported DSL version {imported_version} is newer than " + f"the current supported version {current_dsl_version}. " + f"Please upgrade your Dify instance to import this configuration." + ) + else: + logger.warning( + f"DSL version {imported_version} is older than " + f"the current version {current_dsl_version}. " + f"This may cause compatibility issues." + ) + + return import_data diff --git a/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py b/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py new file mode 100644 index 0000000000..7982e7eed1 --- /dev/null +++ b/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py @@ -0,0 +1,41 @@ +import pytest +from packaging import version + +from services.app_dsl_service import AppDslService +from services.app_dsl_service.exc import DSLVersionNotSupportedError +from services.app_dsl_service.service import _check_or_fix_dsl, current_dsl_version + + +class TestAppDSLService: + def test_check_or_fix_dsl_missing_version(self): + import_data = {} + result = _check_or_fix_dsl(import_data) + assert result["version"] == "0.1.0" + assert result["kind"] == "app" + + def test_check_or_fix_dsl_missing_kind(self): + import_data = {"version": "0.1.0"} + result = _check_or_fix_dsl(import_data) + assert result["kind"] == "app" + + def test_check_or_fix_dsl_older_version(self): + import_data = {"version": "0.0.9", "kind": "app"} + result = _check_or_fix_dsl(import_data) + assert result["version"] == "0.0.9" + + def test_check_or_fix_dsl_current_version(self): + import_data = {"version": current_dsl_version, "kind": "app"} + result = _check_or_fix_dsl(import_data) + assert result["version"] == current_dsl_version + + def test_check_or_fix_dsl_newer_version(self): + current_version = version.parse(current_dsl_version) + newer_version = f"{current_version.major}.{current_version.minor + 1}.0" + import_data = {"version": newer_version, "kind": "app"} + with pytest.raises(DSLVersionNotSupportedError): + _check_or_fix_dsl(import_data) + + def test_check_or_fix_dsl_invalid_kind(self): + import_data = {"version": current_dsl_version, "kind": "invalid"} + result = _check_or_fix_dsl(import_data) + assert result["kind"] == "app"