mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 21:48:58 +08:00
improve: extract method for safe loading yaml file and avoid using PyYaml's FullLoader (#4031)
This commit is contained in:
parent
296887754f
commit
3fda2245a4
@ -3,8 +3,6 @@ import os
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||||
from core.model_runtime.entities.model_entities import (
|
from core.model_runtime.entities.model_entities import (
|
||||||
@ -18,6 +16,7 @@ from core.model_runtime.entities.model_entities import (
|
|||||||
)
|
)
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||||
|
from core.tools.utils.yaml_utils import load_yaml_file
|
||||||
from core.utils.position_helper import get_position_map, sort_by_position_map
|
from core.utils.position_helper import get_position_map, sort_by_position_map
|
||||||
|
|
||||||
|
|
||||||
@ -154,8 +153,7 @@ class AIModel(ABC):
|
|||||||
# traverse all model_schema_yaml_paths
|
# traverse all model_schema_yaml_paths
|
||||||
for model_schema_yaml_path in model_schema_yaml_paths:
|
for model_schema_yaml_path in model_schema_yaml_paths:
|
||||||
# read yaml data from yaml file
|
# read yaml data from yaml file
|
||||||
with open(model_schema_yaml_path, encoding='utf-8') as f:
|
yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
|
||||||
yaml_data = yaml.safe_load(f)
|
|
||||||
|
|
||||||
new_parameter_rules = []
|
new_parameter_rules = []
|
||||||
for parameter_rule in yaml_data.get('parameter_rules', []):
|
for parameter_rule in yaml_data.get('parameter_rules', []):
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||||
|
from core.tools.utils.yaml_utils import load_yaml_file
|
||||||
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
|
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
|
||||||
|
|
||||||
|
|
||||||
@ -44,10 +43,7 @@ class ModelProvider(ABC):
|
|||||||
|
|
||||||
# read provider schema from yaml file
|
# read provider schema from yaml file
|
||||||
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
|
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
|
||||||
yaml_data = {}
|
yaml_data = load_yaml_file(yaml_path, ignore_error=True)
|
||||||
if os.path.exists(yaml_path):
|
|
||||||
with open(yaml_path, encoding='utf-8') as f:
|
|
||||||
yaml_data = yaml.safe_load(f)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# yaml_data to entity
|
# yaml_data to entity
|
||||||
|
@ -2,8 +2,6 @@ from abc import abstractmethod
|
|||||||
from os import listdir, path
|
from os import listdir, path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from yaml import FullLoader, load
|
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
|
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
|
||||||
from core.tools.entities.user_entities import UserToolProviderCredentials
|
from core.tools.entities.user_entities import UserToolProviderCredentials
|
||||||
from core.tools.errors import (
|
from core.tools.errors import (
|
||||||
@ -15,6 +13,7 @@ from core.tools.errors import (
|
|||||||
from core.tools.provider.tool_provider import ToolProviderController
|
from core.tools.provider.tool_provider import ToolProviderController
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
|
from core.tools.utils.yaml_utils import load_yaml_file
|
||||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||||
|
|
||||||
|
|
||||||
@ -28,10 +27,9 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||||||
provider = self.__class__.__module__.split('.')[-1]
|
provider = self.__class__.__module__.split('.')[-1]
|
||||||
yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
|
yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
|
||||||
try:
|
try:
|
||||||
with open(yaml_path, 'rb') as f:
|
provider_yaml = load_yaml_file(yaml_path)
|
||||||
provider_yaml = load(f.read(), FullLoader)
|
except Exception as e:
|
||||||
except:
|
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}')
|
||||||
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}')
|
|
||||||
|
|
||||||
if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None:
|
if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None:
|
||||||
# set credentials name
|
# set credentials name
|
||||||
@ -58,18 +56,18 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||||||
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
|
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
|
||||||
tools = []
|
tools = []
|
||||||
for tool_file in tool_files:
|
for tool_file in tool_files:
|
||||||
with open(path.join(tool_path, tool_file), encoding='utf-8') as f:
|
# get tool name
|
||||||
# get tool name
|
tool_name = tool_file.split(".")[0]
|
||||||
tool_name = tool_file.split(".")[0]
|
tool = load_yaml_file(path.join(tool_path, tool_file))
|
||||||
tool = load(f.read(), FullLoader)
|
|
||||||
# get tool class, import the module
|
# get tool class, import the module
|
||||||
assistant_tool_class = load_single_subclass_from_source(
|
assistant_tool_class = load_single_subclass_from_source(
|
||||||
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
|
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
|
||||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||||
'builtin', provider, 'tools', f'{tool_name}.py'),
|
'builtin', provider, 'tools', f'{tool_name}.py'),
|
||||||
parent_type=BuiltinTool)
|
parent_type=BuiltinTool)
|
||||||
tool["identity"]["provider"] = provider
|
tool["identity"]["provider"] = provider
|
||||||
tools.append(assistant_tool_class(**tool))
|
tools.append(assistant_tool_class(**tool))
|
||||||
|
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
return tools
|
return tools
|
||||||
|
@ -58,7 +58,7 @@ class ToolConfigurationManager(BaseModel):
|
|||||||
if len(credentials[field_name]) > 6:
|
if len(credentials[field_name]) > 6:
|
||||||
credentials[field_name] = \
|
credentials[field_name] = \
|
||||||
credentials[field_name][:2] + \
|
credentials[field_name][:2] + \
|
||||||
'*' * (len(credentials[field_name]) - 4) +\
|
'*' * (len(credentials[field_name]) - 4) + \
|
||||||
credentials[field_name][-2:]
|
credentials[field_name][-2:]
|
||||||
else:
|
else:
|
||||||
credentials[field_name] = '*' * len(credentials[field_name])
|
credentials[field_name] = '*' * len(credentials[field_name])
|
||||||
@ -157,7 +157,7 @@ class ToolParameterConfigurationManager(BaseModel):
|
|||||||
if len(parameters[parameter.name]) > 6:
|
if len(parameters[parameter.name]) > 6:
|
||||||
parameters[parameter.name] = \
|
parameters[parameter.name] = \
|
||||||
parameters[parameter.name][:2] + \
|
parameters[parameter.name][:2] + \
|
||||||
'*' * (len(parameters[parameter.name]) - 4) +\
|
'*' * (len(parameters[parameter.name]) - 4) + \
|
||||||
parameters[parameter.name][-2:]
|
parameters[parameter.name][-2:]
|
||||||
else:
|
else:
|
||||||
parameters[parameter.name] = '*' * len(parameters[parameter.name])
|
parameters[parameter.name] = '*' * len(parameters[parameter.name])
|
||||||
|
34
api/core/tools/utils/yaml_utils.py
Normal file
34
api/core/tools/utils/yaml_utils.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from yaml import YAMLError
|
||||||
|
|
||||||
|
|
||||||
|
def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
|
||||||
|
"""
|
||||||
|
Safe loading a YAML file to a dict
|
||||||
|
:param file_path: the path of the YAML file
|
||||||
|
:param ignore_error:
|
||||||
|
if True, return empty dict if error occurs and the error will be logged in warning level
|
||||||
|
if False, raise error if error occurs
|
||||||
|
:return: a dict of the YAML content
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not file_path or not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')
|
||||||
|
|
||||||
|
with open(file_path, encoding='utf-8') as file:
|
||||||
|
try:
|
||||||
|
return yaml.safe_load(file)
|
||||||
|
except Exception as e:
|
||||||
|
raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logging.debug(f'Failed to load YAML file {file_path}: {e}')
|
||||||
|
return {}
|
||||||
|
except Exception as e:
|
||||||
|
if ignore_error:
|
||||||
|
logging.warning(f'Failed to load YAML file {file_path}: {e}')
|
||||||
|
return {}
|
||||||
|
else:
|
||||||
|
raise e
|
@ -1,10 +1,9 @@
|
|||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any, AnyStr
|
from typing import Any, AnyStr
|
||||||
|
|
||||||
import yaml
|
from core.tools.utils.yaml_utils import load_yaml_file
|
||||||
|
|
||||||
|
|
||||||
def get_position_map(
|
def get_position_map(
|
||||||
@ -17,21 +16,15 @@ def get_position_map(
|
|||||||
:param file_name: the YAML file name, default to '_position.yaml'
|
:param file_name: the YAML file name, default to '_position.yaml'
|
||||||
:return: a dict with name as key and index as value
|
:return: a dict with name as key and index as value
|
||||||
"""
|
"""
|
||||||
try:
|
position_file_name = os.path.join(folder_path, file_name)
|
||||||
position_file_name = os.path.join(folder_path, file_name)
|
positions = load_yaml_file(position_file_name, ignore_error=True)
|
||||||
if not os.path.exists(position_file_name):
|
position_map = {}
|
||||||
return {}
|
index = 0
|
||||||
|
for _, name in enumerate(positions):
|
||||||
with open(position_file_name, encoding='utf-8') as f:
|
if name and isinstance(name, str):
|
||||||
positions = yaml.safe_load(f)
|
position_map[name.strip()] = index
|
||||||
position_map = {}
|
index += 1
|
||||||
for index, name in enumerate(positions):
|
return position_map
|
||||||
if name and isinstance(name, str):
|
|
||||||
position_map[name.strip()] = index
|
|
||||||
return position_map
|
|
||||||
except:
|
|
||||||
logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def sort_by_position_map(
|
def sort_by_position_map(
|
||||||
|
@ -14,6 +14,7 @@ select = [
|
|||||||
"I", # isort rules
|
"I", # isort rules
|
||||||
"UP", # pyupgrade rules
|
"UP", # pyupgrade rules
|
||||||
"RUF019", # unnecessary-key-check
|
"RUF019", # unnecessary-key-check
|
||||||
|
"S506", # unsafe-yaml-load
|
||||||
]
|
]
|
||||||
ignore = [
|
ignore = [
|
||||||
"F403", # undefined-local-with-import-star
|
"F403", # undefined-local-with-import-star
|
||||||
|
0
api/tests/unit_tests/utils/__init__.py
Normal file
0
api/tests/unit_tests/utils/__init__.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from textwrap import dedent
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.utils.position_helper import get_position_map
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
tmp_path.joinpath("example_positions.yaml").write_text(dedent(
|
||||||
|
"""\
|
||||||
|
- first
|
||||||
|
- second
|
||||||
|
# - commented
|
||||||
|
- third
|
||||||
|
|
||||||
|
- 9999999999999
|
||||||
|
- forth
|
||||||
|
"""))
|
||||||
|
return str(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_position_helper(prepare_example_positions_yaml):
|
||||||
|
position_map = get_position_map(
|
||||||
|
folder_path=prepare_example_positions_yaml,
|
||||||
|
file_name='example_positions.yaml')
|
||||||
|
assert len(position_map) == 4
|
||||||
|
assert position_map == {
|
||||||
|
'first': 0,
|
||||||
|
'second': 1,
|
||||||
|
'third': 2,
|
||||||
|
'forth': 3,
|
||||||
|
}
|
0
api/tests/unit_tests/utils/yaml/__init__.py
Normal file
0
api/tests/unit_tests/utils/yaml/__init__.py
Normal file
74
api/tests/unit_tests/utils/yaml/test_yaml_utils.py
Normal file
74
api/tests/unit_tests/utils/yaml/test_yaml_utils.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
from textwrap import dedent
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from yaml import YAMLError
|
||||||
|
|
||||||
|
from core.tools.utils.yaml_utils import load_yaml_file
|
||||||
|
|
||||||
|
EXAMPLE_YAML_FILE = 'example_yaml.yaml'
|
||||||
|
INVALID_YAML_FILE = 'invalid_yaml.yaml'
|
||||||
|
NON_EXISTING_YAML_FILE = 'non_existing_file.yaml'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE)
|
||||||
|
file_path.write_text(dedent(
|
||||||
|
"""\
|
||||||
|
address:
|
||||||
|
city: Example City
|
||||||
|
country: Example Country
|
||||||
|
age: 30
|
||||||
|
gender: male
|
||||||
|
languages:
|
||||||
|
- Python
|
||||||
|
- Java
|
||||||
|
- C++
|
||||||
|
empty_key:
|
||||||
|
"""))
|
||||||
|
return str(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
file_path = tmp_path.joinpath(INVALID_YAML_FILE)
|
||||||
|
file_path.write_text(dedent(
|
||||||
|
"""\
|
||||||
|
address:
|
||||||
|
city: Example City
|
||||||
|
country: Example Country
|
||||||
|
age: 30
|
||||||
|
gender: male
|
||||||
|
languages:
|
||||||
|
- Python
|
||||||
|
- Java
|
||||||
|
- C++
|
||||||
|
"""))
|
||||||
|
return str(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_yaml_non_existing_file():
|
||||||
|
assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
|
||||||
|
assert load_yaml_file(file_path='') == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_valid_yaml_file(prepare_example_yaml_file):
|
||||||
|
yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
|
||||||
|
assert len(yaml_data) > 0
|
||||||
|
assert yaml_data['age'] == 30
|
||||||
|
assert yaml_data['gender'] == 'male'
|
||||||
|
assert yaml_data['address']['city'] == 'Example City'
|
||||||
|
assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'}
|
||||||
|
assert yaml_data.get('empty_key') is None
|
||||||
|
assert yaml_data.get('non_existed_key') is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
|
||||||
|
# yaml syntax error
|
||||||
|
with pytest.raises(YAMLError):
|
||||||
|
load_yaml_file(file_path=prepare_invalid_yaml_file)
|
||||||
|
|
||||||
|
# ignore error
|
||||||
|
assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {}
|
Loading…
x
Reference in New Issue
Block a user