improve: extract method for safe loading yaml file and avoid using PyYaml's FullLoader (#4031)

This commit is contained in:
Bowen Liang 2024-05-24 12:08:12 +08:00 committed by GitHub
parent 296887754f
commit 3fda2245a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 190 additions and 62 deletions

View File

@ -3,8 +3,6 @@ import os
from abc import ABC, abstractmethod
from typing import Optional
import yaml
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import (
@ -18,6 +16,7 @@ from core.model_runtime.entities.model_entities import (
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.position_helper import get_position_map, sort_by_position_map
@ -154,8 +153,7 @@ class AIModel(ABC):
# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:
# read yaml data from yaml file
with open(model_schema_yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
new_parameter_rules = []
for parameter_rule in yaml_data.get('parameter_rules', []):

View File

@ -1,11 +1,10 @@
import os
from abc import ABC, abstractmethod
import yaml
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
@ -44,10 +43,7 @@ class ModelProvider(ABC):
# read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = {}
if os.path.exists(yaml_path):
with open(yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
yaml_data = load_yaml_file(yaml_path, ignore_error=True)
try:
# yaml_data to entity

View File

@ -2,8 +2,6 @@ from abc import abstractmethod
from os import listdir, path
from typing import Any
from yaml import FullLoader, load
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.errors import (
@ -15,6 +13,7 @@ from core.tools.errors import (
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.module_import_helper import load_single_subclass_from_source
@ -28,10 +27,9 @@ class BuiltinToolProviderController(ToolProviderController):
provider = self.__class__.__module__.split('.')[-1]
yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
try:
with open(yaml_path, 'rb') as f:
provider_yaml = load(f.read(), FullLoader)
except:
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}')
provider_yaml = load_yaml_file(yaml_path)
except Exception as e:
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}')
if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None:
# set credentials name
@ -58,18 +56,18 @@ class BuiltinToolProviderController(ToolProviderController):
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
tools = []
for tool_file in tool_files:
with open(path.join(tool_path, tool_file), encoding='utf-8') as f:
# get tool name
tool_name = tool_file.split(".")[0]
tool = load(f.read(), FullLoader)
# get tool class, import the module
assistant_tool_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'builtin', provider, 'tools', f'{tool_name}.py'),
parent_type=BuiltinTool)
tool["identity"]["provider"] = provider
tools.append(assistant_tool_class(**tool))
# get tool name
tool_name = tool_file.split(".")[0]
tool = load_yaml_file(path.join(tool_path, tool_file))
# get tool class, import the module
assistant_tool_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'builtin', provider, 'tools', f'{tool_name}.py'),
parent_type=BuiltinTool)
tool["identity"]["provider"] = provider
tools.append(assistant_tool_class(**tool))
self.tools = tools
return tools

View File

@ -23,7 +23,7 @@ class ToolConfigurationManager(BaseModel):
deep copy credentials
"""
return deepcopy(credentials)
def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
@ -39,9 +39,9 @@ class ToolConfigurationManager(BaseModel):
if field_name in credentials:
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
credentials[field_name] = encrypted
return credentials
def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
@ -58,7 +58,7 @@ class ToolConfigurationManager(BaseModel):
if len(credentials[field_name]) > 6:
credentials[field_name] = \
credentials[field_name][:2] + \
'*' * (len(credentials[field_name]) - 4) +\
'*' * (len(credentials[field_name]) - 4) + \
credentials[field_name][-2:]
else:
credentials[field_name] = '*' * len(credentials[field_name])
@ -72,7 +72,7 @@ class ToolConfigurationManager(BaseModel):
return a deep copy of credentials with decrypted values
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
@ -92,10 +92,10 @@ class ToolConfigurationManager(BaseModel):
cache.set(credentials)
return credentials
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
@ -116,7 +116,7 @@ class ToolParameterConfigurationManager(BaseModel):
deep copy parameters
"""
return deepcopy(parameters)
def _merge_parameters(self) -> list[ToolParameter]:
"""
merge parameters
@ -139,7 +139,7 @@ class ToolParameterConfigurationManager(BaseModel):
current_parameters.append(runtime_parameter)
return current_parameters
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
mask tool parameters
@ -157,13 +157,13 @@ class ToolParameterConfigurationManager(BaseModel):
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = \
parameters[parameter.name][:2] + \
'*' * (len(parameters[parameter.name]) - 4) +\
'*' * (len(parameters[parameter.name]) - 4) + \
parameters[parameter.name][-2:]
else:
parameters[parameter.name] = '*' * len(parameters[parameter.name])
return parameters
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
encrypt tool parameters with tenant id
@ -180,9 +180,9 @@ class ToolParameterConfigurationManager(BaseModel):
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted
return parameters
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
decrypt tool parameters with tenant id
@ -190,7 +190,7 @@ class ToolParameterConfigurationManager(BaseModel):
return a deep copy of parameters with decrypted values
"""
cache = ToolParameterCache(
tenant_id=self.tenant_id,
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
@ -212,15 +212,15 @@ class ToolParameterConfigurationManager(BaseModel):
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except:
pass
if has_secret_input:
cache.set(parameters)
return parameters
def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
tenant_id=self.tenant_id,
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,

View File

@ -0,0 +1,34 @@
import logging
import os
import yaml
from yaml import YAMLError
def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
"""
Safe loading a YAML file to a dict
:param file_path: the path of the YAML file
:param ignore_error:
if True, return empty dict if error occurs and the error will be logged in warning level
if False, raise error if error occurs
:return: a dict of the YAML content
"""
try:
if not file_path or not os.path.exists(file_path):
raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')
with open(file_path, encoding='utf-8') as file:
try:
return yaml.safe_load(file)
except Exception as e:
raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
except FileNotFoundError as e:
logging.debug(f'Failed to load YAML file {file_path}: {e}')
return {}
except Exception as e:
if ignore_error:
logging.warning(f'Failed to load YAML file {file_path}: {e}')
return {}
else:
raise e

View File

@ -1,10 +1,9 @@
import logging
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, AnyStr
import yaml
from core.tools.utils.yaml_utils import load_yaml_file
def get_position_map(
@ -17,21 +16,15 @@ def get_position_map(
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
try:
position_file_name = os.path.join(folder_path, file_name)
if not os.path.exists(position_file_name):
return {}
with open(position_file_name, encoding='utf-8') as f:
positions = yaml.safe_load(f)
position_map = {}
for index, name in enumerate(positions):
if name and isinstance(name, str):
position_map[name.strip()] = index
return position_map
except:
logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
return {}
position_file_name = os.path.join(folder_path, file_name)
positions = load_yaml_file(position_file_name, ignore_error=True)
position_map = {}
index = 0
for _, name in enumerate(positions):
if name and isinstance(name, str):
position_map[name.strip()] = index
index += 1
return position_map
def sort_by_position_map(

View File

@ -14,6 +14,7 @@ select = [
"I", # isort rules
"UP", # pyupgrade rules
"RUF019", # unnecessary-key-check
"S506", # unsafe-yaml-load
]
ignore = [
"F403", # undefined-local-with-import-star

View File

View File

@ -0,0 +1,34 @@
from textwrap import dedent
import pytest
from core.utils.position_helper import get_position_map
@pytest.fixture
def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
tmp_path.joinpath("example_positions.yaml").write_text(dedent(
"""\
- first
- second
# - commented
- third
- 9999999999999
- forth
"""))
return str(tmp_path)
def test_position_helper(prepare_example_positions_yaml):
position_map = get_position_map(
folder_path=prepare_example_positions_yaml,
file_name='example_positions.yaml')
assert len(position_map) == 4
assert position_map == {
'first': 0,
'second': 1,
'third': 2,
'forth': 3,
}

View File

@ -0,0 +1,74 @@
from textwrap import dedent
import pytest
from yaml import YAMLError
from core.tools.utils.yaml_utils import load_yaml_file
EXAMPLE_YAML_FILE = 'example_yaml.yaml'
INVALID_YAML_FILE = 'invalid_yaml.yaml'
NON_EXISTING_YAML_FILE = 'non_existing_file.yaml'
@pytest.fixture
def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE)
file_path.write_text(dedent(
"""\
address:
city: Example City
country: Example Country
age: 30
gender: male
languages:
- Python
- Java
- C++
empty_key:
"""))
return str(file_path)
@pytest.fixture
def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(INVALID_YAML_FILE)
file_path.write_text(dedent(
"""\
address:
city: Example City
country: Example Country
age: 30
gender: male
languages:
- Python
- Java
- C++
"""))
return str(file_path)
def test_load_yaml_non_existing_file():
assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
assert load_yaml_file(file_path='') == {}
def test_load_valid_yaml_file(prepare_example_yaml_file):
yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
assert len(yaml_data) > 0
assert yaml_data['age'] == 30
assert yaml_data['gender'] == 'male'
assert yaml_data['address']['city'] == 'Example City'
assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'}
assert yaml_data.get('empty_key') is None
assert yaml_data.get('non_existed_key') is None
def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
# yaml syntax error
with pytest.raises(YAMLError):
load_yaml_file(file_path=prepare_invalid_yaml_file)
# ignore error
assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {}