improve: extract method for safe loading yaml file and avoid using PyYaml's FullLoader (#4031)

This commit is contained in:
Bowen Liang 2024-05-24 12:08:12 +08:00 committed by GitHub
parent 296887754f
commit 3fda2245a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 190 additions and 62 deletions

View File

@ -3,8 +3,6 @@ import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
import yaml
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.model_entities import (
@ -18,6 +16,7 @@ from core.model_runtime.entities.model_entities import (
) )
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.position_helper import get_position_map, sort_by_position_map from core.utils.position_helper import get_position_map, sort_by_position_map
@ -154,8 +153,7 @@ class AIModel(ABC):
# traverse all model_schema_yaml_paths # traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths: for model_schema_yaml_path in model_schema_yaml_paths:
# read yaml data from yaml file # read yaml data from yaml file
with open(model_schema_yaml_path, encoding='utf-8') as f: yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
yaml_data = yaml.safe_load(f)
new_parameter_rules = [] new_parameter_rules = []
for parameter_rule in yaml_data.get('parameter_rules', []): for parameter_rule in yaml_data.get('parameter_rules', []):

View File

@ -1,11 +1,10 @@
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import yaml
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
@ -44,10 +43,7 @@ class ModelProvider(ABC):
# read provider schema from yaml file # read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml') yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = {} yaml_data = load_yaml_file(yaml_path, ignore_error=True)
if os.path.exists(yaml_path):
with open(yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
try: try:
# yaml_data to entity # yaml_data to entity

View File

@ -2,8 +2,6 @@ from abc import abstractmethod
from os import listdir, path from os import listdir, path
from typing import Any from typing import Any
from yaml import FullLoader, load
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
from core.tools.entities.user_entities import UserToolProviderCredentials from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.errors import ( from core.tools.errors import (
@ -15,6 +13,7 @@ from core.tools.errors import (
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.module_import_helper import load_single_subclass_from_source from core.utils.module_import_helper import load_single_subclass_from_source
@ -28,10 +27,9 @@ class BuiltinToolProviderController(ToolProviderController):
provider = self.__class__.__module__.split('.')[-1] provider = self.__class__.__module__.split('.')[-1]
yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml') yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
try: try:
with open(yaml_path, 'rb') as f: provider_yaml = load_yaml_file(yaml_path)
provider_yaml = load(f.read(), FullLoader) except Exception as e:
except: raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}')
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}')
if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None: if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None:
# set credentials name # set credentials name
@ -58,18 +56,18 @@ class BuiltinToolProviderController(ToolProviderController):
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path))) tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
tools = [] tools = []
for tool_file in tool_files: for tool_file in tool_files:
with open(path.join(tool_path, tool_file), encoding='utf-8') as f: # get tool name
# get tool name tool_name = tool_file.split(".")[0]
tool_name = tool_file.split(".")[0] tool = load_yaml_file(path.join(tool_path, tool_file))
tool = load(f.read(), FullLoader)
# get tool class, import the module # get tool class, import the module
assistant_tool_class = load_single_subclass_from_source( assistant_tool_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}', module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
script_path=path.join(path.dirname(path.realpath(__file__)), script_path=path.join(path.dirname(path.realpath(__file__)),
'builtin', provider, 'tools', f'{tool_name}.py'), 'builtin', provider, 'tools', f'{tool_name}.py'),
parent_type=BuiltinTool) parent_type=BuiltinTool)
tool["identity"]["provider"] = provider tool["identity"]["provider"] = provider
tools.append(assistant_tool_class(**tool)) tools.append(assistant_tool_class(**tool))
self.tools = tools self.tools = tools
return tools return tools

View File

@ -58,7 +58,7 @@ class ToolConfigurationManager(BaseModel):
if len(credentials[field_name]) > 6: if len(credentials[field_name]) > 6:
credentials[field_name] = \ credentials[field_name] = \
credentials[field_name][:2] + \ credentials[field_name][:2] + \
'*' * (len(credentials[field_name]) - 4) +\ '*' * (len(credentials[field_name]) - 4) + \
credentials[field_name][-2:] credentials[field_name][-2:]
else: else:
credentials[field_name] = '*' * len(credentials[field_name]) credentials[field_name] = '*' * len(credentials[field_name])
@ -157,7 +157,7 @@ class ToolParameterConfigurationManager(BaseModel):
if len(parameters[parameter.name]) > 6: if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = \ parameters[parameter.name] = \
parameters[parameter.name][:2] + \ parameters[parameter.name][:2] + \
'*' * (len(parameters[parameter.name]) - 4) +\ '*' * (len(parameters[parameter.name]) - 4) + \
parameters[parameter.name][-2:] parameters[parameter.name][-2:]
else: else:
parameters[parameter.name] = '*' * len(parameters[parameter.name]) parameters[parameter.name] = '*' * len(parameters[parameter.name])

View File

@ -0,0 +1,34 @@
import logging
import os
import yaml
from yaml import YAMLError
def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
"""
Safe loading a YAML file to a dict
:param file_path: the path of the YAML file
:param ignore_error:
if True, return empty dict if error occurs and the error will be logged in warning level
if False, raise error if error occurs
:return: a dict of the YAML content
"""
try:
if not file_path or not os.path.exists(file_path):
raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')
with open(file_path, encoding='utf-8') as file:
try:
return yaml.safe_load(file)
except Exception as e:
raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
except FileNotFoundError as e:
logging.debug(f'Failed to load YAML file {file_path}: {e}')
return {}
except Exception as e:
if ignore_error:
logging.warning(f'Failed to load YAML file {file_path}: {e}')
return {}
else:
raise e

View File

@ -1,10 +1,9 @@
import logging
import os import os
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable from collections.abc import Callable
from typing import Any, AnyStr from typing import Any, AnyStr
import yaml from core.tools.utils.yaml_utils import load_yaml_file
def get_position_map( def get_position_map(
@ -17,21 +16,15 @@ def get_position_map(
:param file_name: the YAML file name, default to '_position.yaml' :param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value :return: a dict with name as key and index as value
""" """
try: position_file_name = os.path.join(folder_path, file_name)
position_file_name = os.path.join(folder_path, file_name) positions = load_yaml_file(position_file_name, ignore_error=True)
if not os.path.exists(position_file_name): position_map = {}
return {} index = 0
for _, name in enumerate(positions):
with open(position_file_name, encoding='utf-8') as f: if name and isinstance(name, str):
positions = yaml.safe_load(f) position_map[name.strip()] = index
position_map = {} index += 1
for index, name in enumerate(positions): return position_map
if name and isinstance(name, str):
position_map[name.strip()] = index
return position_map
except:
logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
return {}
def sort_by_position_map( def sort_by_position_map(

View File

@ -14,6 +14,7 @@ select = [
"I", # isort rules "I", # isort rules
"UP", # pyupgrade rules "UP", # pyupgrade rules
"RUF019", # unnecessary-key-check "RUF019", # unnecessary-key-check
"S506", # unsafe-yaml-load
] ]
ignore = [ ignore = [
"F403", # undefined-local-with-import-star "F403", # undefined-local-with-import-star

View File

View File

@ -0,0 +1,34 @@
from textwrap import dedent
import pytest
from core.utils.position_helper import get_position_map
@pytest.fixture
def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
tmp_path.joinpath("example_positions.yaml").write_text(dedent(
"""\
- first
- second
# - commented
- third
- 9999999999999
- forth
"""))
return str(tmp_path)
def test_position_helper(prepare_example_positions_yaml):
position_map = get_position_map(
folder_path=prepare_example_positions_yaml,
file_name='example_positions.yaml')
assert len(position_map) == 4
assert position_map == {
'first': 0,
'second': 1,
'third': 2,
'forth': 3,
}

View File

@ -0,0 +1,74 @@
from textwrap import dedent
import pytest
from yaml import YAMLError
from core.tools.utils.yaml_utils import load_yaml_file
EXAMPLE_YAML_FILE = 'example_yaml.yaml'
INVALID_YAML_FILE = 'invalid_yaml.yaml'
NON_EXISTING_YAML_FILE = 'non_existing_file.yaml'
@pytest.fixture
def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE)
file_path.write_text(dedent(
"""\
address:
city: Example City
country: Example Country
age: 30
gender: male
languages:
- Python
- Java
- C++
empty_key:
"""))
return str(file_path)
@pytest.fixture
def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(INVALID_YAML_FILE)
file_path.write_text(dedent(
"""\
address:
city: Example City
country: Example Country
age: 30
gender: male
languages:
- Python
- Java
- C++
"""))
return str(file_path)
def test_load_yaml_non_existing_file():
assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
assert load_yaml_file(file_path='') == {}
def test_load_valid_yaml_file(prepare_example_yaml_file):
yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
assert len(yaml_data) > 0
assert yaml_data['age'] == 30
assert yaml_data['gender'] == 'male'
assert yaml_data['address']['city'] == 'Example City'
assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'}
assert yaml_data.get('empty_key') is None
assert yaml_data.get('non_existed_key') is None
def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
# yaml syntax error
with pytest.raises(YAMLError):
load_yaml_file(file_path=prepare_invalid_yaml_file)
# ignore error
assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {}