diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 34a7375493..cd243ca223 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -3,8 +3,6 @@ import os from abc import ABC, abstractmethod from typing import Optional -import yaml - from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.model_entities import ( @@ -18,6 +16,7 @@ from core.model_runtime.entities.model_entities import ( ) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from core.tools.utils.yaml_utils import load_yaml_file from core.utils.position_helper import get_position_map, sort_by_position_map @@ -154,8 +153,7 @@ class AIModel(ABC): # traverse all model_schema_yaml_paths for model_schema_yaml_path in model_schema_yaml_paths: # read yaml data from yaml file - with open(model_schema_yaml_path, encoding='utf-8') as f: - yaml_data = yaml.safe_load(f) + yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True) new_parameter_rules = [] for parameter_rule in yaml_data.get('parameter_rules', []): diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 7c839a9672..9ab78b7610 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -1,11 +1,10 @@ import os from abc import ABC, abstractmethod -import yaml - from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderEntity from 
core.model_runtime.model_providers.__base.ai_model import AIModel +from core.tools.utils.yaml_utils import load_yaml_file from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source @@ -44,10 +43,7 @@ class ModelProvider(ABC): # read provider schema from yaml file yaml_path = os.path.join(current_path, f'{provider_name}.yaml') - yaml_data = {} - if os.path.exists(yaml_path): - with open(yaml_path, encoding='utf-8') as f: - yaml_data = yaml.safe_load(f) + yaml_data = load_yaml_file(yaml_path, ignore_error=True) try: # yaml_data to entity diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index c2178cdd40..76ee473beb 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -2,8 +2,6 @@ from abc import abstractmethod from os import listdir, path from typing import Any -from yaml import FullLoader, load - from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType from core.tools.entities.user_entities import UserToolProviderCredentials from core.tools.errors import ( @@ -15,6 +13,7 @@ from core.tools.errors import ( from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool +from core.tools.utils.yaml_utils import load_yaml_file from core.utils.module_import_helper import load_single_subclass_from_source @@ -28,10 +27,9 @@ class BuiltinToolProviderController(ToolProviderController): provider = self.__class__.__module__.split('.')[-1] yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml') try: - with open(yaml_path, 'rb') as f: - provider_yaml = load(f.read(), FullLoader) - except: - raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}') + provider_yaml = load_yaml_file(yaml_path) + except Exception 
as e: + raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}') if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None: # set credentials name @@ -58,18 +56,18 @@ class BuiltinToolProviderController(ToolProviderController): tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path))) tools = [] for tool_file in tool_files: - with open(path.join(tool_path, tool_file), encoding='utf-8') as f: - # get tool name - tool_name = tool_file.split(".")[0] - tool = load(f.read(), FullLoader) - # get tool class, import the module - assistant_tool_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}', - script_path=path.join(path.dirname(path.realpath(__file__)), - 'builtin', provider, 'tools', f'{tool_name}.py'), - parent_type=BuiltinTool) - tool["identity"]["provider"] = provider - tools.append(assistant_tool_class(**tool)) + # get tool name + tool_name = tool_file.split(".")[0] + tool = load_yaml_file(path.join(tool_path, tool_file)) + + # get tool class, import the module + assistant_tool_class = load_single_subclass_from_source( + module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}', + script_path=path.join(path.dirname(path.realpath(__file__)), + 'builtin', provider, 'tools', f'{tool_name}.py'), + parent_type=BuiltinTool) + tool["identity"]["provider"] = provider + tools.append(assistant_tool_class(**tool)) self.tools = tools return tools diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index b68efad124..90b39a7fc9 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -23,7 +23,7 @@ class ToolConfigurationManager(BaseModel): deep copy credentials """ return deepcopy(credentials) - + def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]: """ encrypt tool 
credentials with tenant id @@ -39,9 +39,9 @@ class ToolConfigurationManager(BaseModel): if field_name in credentials: encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name]) credentials[field_name] = encrypted - + return credentials - + def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]: """ mask tool credentials @@ -58,7 +58,7 @@ class ToolConfigurationManager(BaseModel): if len(credentials[field_name]) > 6: credentials[field_name] = \ credentials[field_name][:2] + \ - '*' * (len(credentials[field_name]) - 4) +\ + '*' * (len(credentials[field_name]) - 4) + \ credentials[field_name][-2:] else: credentials[field_name] = '*' * len(credentials[field_name]) @@ -72,7 +72,7 @@ class ToolConfigurationManager(BaseModel): return a deep copy of credentials with decrypted values """ cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, + tenant_id=self.tenant_id, identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}', cache_type=ToolProviderCredentialsCacheType.PROVIDER ) @@ -92,10 +92,10 @@ class ToolConfigurationManager(BaseModel): cache.set(credentials) return credentials - + def delete_tool_credentials_cache(self): cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, + tenant_id=self.tenant_id, identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}', cache_type=ToolProviderCredentialsCacheType.PROVIDER ) @@ -116,7 +116,7 @@ class ToolParameterConfigurationManager(BaseModel): deep copy parameters """ return deepcopy(parameters) - + def _merge_parameters(self) -> list[ToolParameter]: """ merge parameters @@ -139,7 +139,7 @@ class ToolParameterConfigurationManager(BaseModel): current_parameters.append(runtime_parameter) return current_parameters - + def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ mask tool parameters @@ -157,13 +157,13 @@ class 
ToolParameterConfigurationManager(BaseModel): if len(parameters[parameter.name]) > 6: parameters[parameter.name] = \ parameters[parameter.name][:2] + \ - '*' * (len(parameters[parameter.name]) - 4) +\ + '*' * (len(parameters[parameter.name]) - 4) + \ parameters[parameter.name][-2:] else: parameters[parameter.name] = '*' * len(parameters[parameter.name]) return parameters - + def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ encrypt tool parameters with tenant id @@ -180,9 +180,9 @@ class ToolParameterConfigurationManager(BaseModel): if parameter.name in parameters: encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) parameters[parameter.name] = encrypted - + return parameters - + def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ decrypt tool parameters with tenant id @@ -190,7 +190,7 @@ class ToolParameterConfigurationManager(BaseModel): return a deep copy of parameters with decrypted values """ cache = ToolParameterCache( - tenant_id=self.tenant_id, + tenant_id=self.tenant_id, provider=f'{self.provider_type}.{self.provider_name}', tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, @@ -212,15 +212,15 @@ class ToolParameterConfigurationManager(BaseModel): parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) except: pass - + if has_secret_input: cache.set(parameters) return parameters - + def delete_tool_parameters_cache(self): cache = ToolParameterCache( - tenant_id=self.tenant_id, + tenant_id=self.tenant_id, provider=f'{self.provider_type}.{self.provider_name}', tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py new file mode 100644 index 0000000000..22e4d3d128 --- /dev/null +++ b/api/core/tools/utils/yaml_utils.py @@ -0,0 +1,34 @@ +import logging +import os + +import 
def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
    """
    Safely load the content of a YAML file.

    :param file_path: the path of the YAML file
    :param ignore_error:
        if True, return an empty dict if an error occurs; the error is logged at warning level
        if False, re-raise the error (a parsing failure is raised as YAMLError)
    :return: the parsed YAML content. An empty dict is returned when the path is empty,
        when the file does not exist (regardless of ``ignore_error``; logged at debug level),
        or when the document is empty. A document whose top level is a sequence is
        returned as a list.
    :raises YAMLError: if the file cannot be parsed and ``ignore_error`` is False
    """
    try:
        if not file_path or not os.path.exists(file_path):
            raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')

        with open(file_path, encoding='utf-8') as file:
            try:
                content = yaml.safe_load(file)
            except Exception as e:
                # keep the original parser/decoder error attached as the cause
                raise YAMLError(f'Failed to load YAML file {file_path}: {e}') from e

        # safe_load yields None for an empty document; callers expect something
        # they can call .get() on or iterate over, so normalize it to an empty dict
        return {} if content is None else content
    except FileNotFoundError as e:
        # a missing file is an expected situation: never raised, only logged at debug level
        logging.debug(f'Failed to load YAML file {file_path}: {e}')
        return {}
    except Exception as e:
        if ignore_error:
            logging.warning(f'Failed to load YAML file {file_path}: {e}')
            return {}
        # bare raise keeps the original traceback intact
        raise
def test_position_helper(prepare_example_positions_yaml):
    """
    The position map built from the example file contains exactly the four
    string entries, numbered consecutively from zero.
    """
    expected_positions = {
        'first': 0,
        'second': 1,
        'third': 2,
        'forth': 3,
    }

    actual_positions = get_position_map(
        folder_path=prepare_example_positions_yaml,
        file_name='example_positions.yaml',
    )

    assert len(actual_positions) == 4
    assert actual_positions == expected_positions
def test_load_valid_yaml_file(prepare_example_yaml_file):
    """
    A well-formed YAML file is loaded into a non-empty mapping whose scalar,
    nested, list and empty values are all accessible.
    """
    loaded = load_yaml_file(file_path=prepare_example_yaml_file)

    # the document must not come back empty
    assert len(loaded) > 0

    # top-level scalars
    assert loaded['age'] == 30
    assert loaded['gender'] == 'male'

    # nested mapping
    address = loaded['address']
    assert address['city'] == 'Example City'

    # sequence content, order-insensitive
    languages = set(loaded['languages'])
    assert languages == {'Python', 'Java', 'C++'}

    # a key without a value and a key that is absent both read as None via .get()
    assert loaded.get('empty_key') is None
    assert loaded.get('non_existed_key') is None
load_yaml_file(file_path=prepare_invalid_yaml_file) + + # ignore error + assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {}