mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 22:19:00 +08:00
improve: extract method for safe loading yaml file and avoid using PyYaml's FullLoader (#4031)
This commit is contained in:
parent
296887754f
commit
3fda2245a4
@ -3,8 +3,6 @@ import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
@ -18,6 +16,7 @@ from core.model_runtime.entities.model_entities import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
from core.utils.position_helper import get_position_map, sort_by_position_map
|
||||
|
||||
|
||||
@ -154,8 +153,7 @@ class AIModel(ABC):
|
||||
# traverse all model_schema_yaml_paths
|
||||
for model_schema_yaml_path in model_schema_yaml_paths:
|
||||
# read yaml data from yaml file
|
||||
with open(model_schema_yaml_path, encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
|
||||
|
||||
new_parameter_rules = []
|
||||
for parameter_rule in yaml_data.get('parameter_rules', []):
|
||||
|
@ -1,11 +1,10 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import yaml
|
||||
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
|
||||
|
||||
|
||||
@ -44,10 +43,7 @@ class ModelProvider(ABC):
|
||||
|
||||
# read provider schema from yaml file
|
||||
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
|
||||
yaml_data = {}
|
||||
if os.path.exists(yaml_path):
|
||||
with open(yaml_path, encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
yaml_data = load_yaml_file(yaml_path, ignore_error=True)
|
||||
|
||||
try:
|
||||
# yaml_data to entity
|
||||
|
@ -2,8 +2,6 @@ from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
|
||||
from yaml import FullLoader, load
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProviderCredentials
|
||||
from core.tools.errors import (
|
||||
@ -15,6 +13,7 @@ from core.tools.errors import (
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
|
||||
|
||||
@ -28,10 +27,9 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
provider = self.__class__.__module__.split('.')[-1]
|
||||
yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
|
||||
try:
|
||||
with open(yaml_path, 'rb') as f:
|
||||
provider_yaml = load(f.read(), FullLoader)
|
||||
except:
|
||||
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}')
|
||||
provider_yaml = load_yaml_file(yaml_path)
|
||||
except Exception as e:
|
||||
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}')
|
||||
|
||||
if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None:
|
||||
# set credentials name
|
||||
@ -58,18 +56,18 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
|
||||
tools = []
|
||||
for tool_file in tool_files:
|
||||
with open(path.join(tool_path, tool_file), encoding='utf-8') as f:
|
||||
# get tool name
|
||||
tool_name = tool_file.split(".")[0]
|
||||
tool = load(f.read(), FullLoader)
|
||||
# get tool class, import the module
|
||||
assistant_tool_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'builtin', provider, 'tools', f'{tool_name}.py'),
|
||||
parent_type=BuiltinTool)
|
||||
tool["identity"]["provider"] = provider
|
||||
tools.append(assistant_tool_class(**tool))
|
||||
# get tool name
|
||||
tool_name = tool_file.split(".")[0]
|
||||
tool = load_yaml_file(path.join(tool_path, tool_file))
|
||||
|
||||
# get tool class, import the module
|
||||
assistant_tool_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'builtin', provider, 'tools', f'{tool_name}.py'),
|
||||
parent_type=BuiltinTool)
|
||||
tool["identity"]["provider"] = provider
|
||||
tools.append(assistant_tool_class(**tool))
|
||||
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
@ -23,7 +23,7 @@ class ToolConfigurationManager(BaseModel):
|
||||
deep copy credentials
|
||||
"""
|
||||
return deepcopy(credentials)
|
||||
|
||||
|
||||
def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
@ -39,9 +39,9 @@ class ToolConfigurationManager(BaseModel):
|
||||
if field_name in credentials:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
|
||||
credentials[field_name] = encrypted
|
||||
|
||||
|
||||
return credentials
|
||||
|
||||
|
||||
def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool credentials
|
||||
@ -58,7 +58,7 @@ class ToolConfigurationManager(BaseModel):
|
||||
if len(credentials[field_name]) > 6:
|
||||
credentials[field_name] = \
|
||||
credentials[field_name][:2] + \
|
||||
'*' * (len(credentials[field_name]) - 4) +\
|
||||
'*' * (len(credentials[field_name]) - 4) + \
|
||||
credentials[field_name][-2:]
|
||||
else:
|
||||
credentials[field_name] = '*' * len(credentials[field_name])
|
||||
@ -72,7 +72,7 @@ class ToolConfigurationManager(BaseModel):
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
@ -92,10 +92,10 @@ class ToolConfigurationManager(BaseModel):
|
||||
|
||||
cache.set(credentials)
|
||||
return credentials
|
||||
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
@ -116,7 +116,7 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
deep copy parameters
|
||||
"""
|
||||
return deepcopy(parameters)
|
||||
|
||||
|
||||
def _merge_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
merge parameters
|
||||
@ -139,7 +139,7 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
return current_parameters
|
||||
|
||||
|
||||
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool parameters
|
||||
@ -157,13 +157,13 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
if len(parameters[parameter.name]) > 6:
|
||||
parameters[parameter.name] = \
|
||||
parameters[parameter.name][:2] + \
|
||||
'*' * (len(parameters[parameter.name]) - 4) +\
|
||||
'*' * (len(parameters[parameter.name]) - 4) + \
|
||||
parameters[parameter.name][-2:]
|
||||
else:
|
||||
parameters[parameter.name] = '*' * len(parameters[parameter.name])
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
encrypt tool parameters with tenant id
|
||||
@ -180,9 +180,9 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
if parameter.name in parameters:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
parameters[parameter.name] = encrypted
|
||||
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool parameters with tenant id
|
||||
@ -190,7 +190,7 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
return a deep copy of parameters with decrypted values
|
||||
"""
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f'{self.provider_type}.{self.provider_name}',
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
@ -212,15 +212,15 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
if has_secret_input:
|
||||
cache.set(parameters)
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f'{self.provider_type}.{self.provider_name}',
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
|
34
api/core/tools/utils/yaml_utils.py
Normal file
34
api/core/tools/utils/yaml_utils.py
Normal file
@ -0,0 +1,34 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import yaml
|
||||
from yaml import YAMLError
|
||||
|
||||
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
    """
    Safely load a YAML file using ``yaml.safe_load``.

    :param file_path: the path of the YAML file
    :param ignore_error:
        if True, log a warning and return an empty dict when reading or parsing fails
        if False, re-raise the error
    :return: the parsed YAML content; an empty dict when ``file_path`` is empty,
        the file does not exist, or the YAML document is empty.
        NOTE: the top-level node is returned as parsed, so a file whose root is a
        sequence yields a list rather than a dict.
    :raises YAMLError: if the file cannot be parsed and ``ignore_error`` is False
    """
    try:
        if not file_path or not os.path.exists(file_path):
            raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')

        with open(file_path, encoding='utf-8') as file:
            try:
                content = yaml.safe_load(file)
            except Exception as e:
                # normalize every read/parse failure to YAMLError, keeping the cause chained
                raise YAMLError(f'Failed to load YAML file {file_path}: {e}') from e

        # safe_load returns None for an empty document; callers expect a container
        # they can call .get() on or iterate over
        return {} if content is None else content
    except FileNotFoundError as e:
        # a missing file is treated as "nothing to load", regardless of ignore_error
        logging.debug(f'Failed to load YAML file {file_path}: {e}')
        return {}
    except Exception as e:
        if not ignore_error:
            raise
        logging.warning(f'Failed to load YAML file {file_path}: {e}')
        return {}
|
@ -1,10 +1,9 @@
|
||||
import logging
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from typing import Any, AnyStr
|
||||
|
||||
import yaml
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
def get_position_map(
|
||||
@ -17,21 +16,15 @@ def get_position_map(
|
||||
:param file_name: the YAML file name, default to '_position.yaml'
|
||||
:return: a dict with name as key and index as value
|
||||
"""
|
||||
try:
|
||||
position_file_name = os.path.join(folder_path, file_name)
|
||||
if not os.path.exists(position_file_name):
|
||||
return {}
|
||||
|
||||
with open(position_file_name, encoding='utf-8') as f:
|
||||
positions = yaml.safe_load(f)
|
||||
position_map = {}
|
||||
for index, name in enumerate(positions):
|
||||
if name and isinstance(name, str):
|
||||
position_map[name.strip()] = index
|
||||
return position_map
|
||||
except:
|
||||
logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
|
||||
return {}
|
||||
position_file_name = os.path.join(folder_path, file_name)
|
||||
positions = load_yaml_file(position_file_name, ignore_error=True)
|
||||
position_map = {}
|
||||
index = 0
|
||||
for _, name in enumerate(positions):
|
||||
if name and isinstance(name, str):
|
||||
position_map[name.strip()] = index
|
||||
index += 1
|
||||
return position_map
|
||||
|
||||
|
||||
def sort_by_position_map(
|
||||
|
@ -14,6 +14,7 @@ select = [
|
||||
"I", # isort rules
|
||||
"UP", # pyupgrade rules
|
||||
"RUF019", # unnecessary-key-check
|
||||
"S506", # unsafe-yaml-load
|
||||
]
|
||||
ignore = [
|
||||
"F403", # undefined-local-with-import-star
|
||||
|
0
api/tests/unit_tests/utils/__init__.py
Normal file
0
api/tests/unit_tests/utils/__init__.py
Normal file
@ -0,0 +1,34 @@
|
||||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
|
||||
from core.utils.position_helper import get_position_map
|
||||
|
||||
|
||||
@pytest.fixture
def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
    """
    Write ``example_positions.yaml`` into a temporary directory and return that
    directory's path as a string.

    Besides four plain names, the file contains a commented-out item, a blank
    line and an integer item.
    """
    monkeypatch.chdir(tmp_path)
    positions_yaml = dedent(
        """\
        - first
        - second
        # - commented
        - third

        - 9999999999999
        - forth
        """)
    positions_file = tmp_path.joinpath("example_positions.yaml")
    positions_file.write_text(positions_yaml)
    return str(tmp_path)
|
||||
|
||||
|
||||
def test_position_helper(prepare_example_positions_yaml):
    """The position map holds exactly the four string names, indexed 0 to 3."""
    expected_positions = {'first': 0, 'second': 1, 'third': 2, 'forth': 3}

    actual_positions = get_position_map(
        folder_path=prepare_example_positions_yaml,
        file_name='example_positions.yaml',
    )

    assert len(actual_positions) == 4
    assert actual_positions == expected_positions
|
0
api/tests/unit_tests/utils/yaml/__init__.py
Normal file
0
api/tests/unit_tests/utils/yaml/__init__.py
Normal file
74
api/tests/unit_tests/utils/yaml/test_yaml_utils.py
Normal file
74
api/tests/unit_tests/utils/yaml/test_yaml_utils.py
Normal file
@ -0,0 +1,74 @@
|
||||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
from yaml import YAMLError
|
||||
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
EXAMPLE_YAML_FILE = 'example_yaml.yaml'
|
||||
INVALID_YAML_FILE = 'invalid_yaml.yaml'
|
||||
NON_EXISTING_YAML_FILE = 'non_existing_file.yaml'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE)
|
||||
file_path.write_text(dedent(
|
||||
"""\
|
||||
address:
|
||||
city: Example City
|
||||
country: Example Country
|
||||
age: 30
|
||||
gender: male
|
||||
languages:
|
||||
- Python
|
||||
- Java
|
||||
- C++
|
||||
empty_key:
|
||||
"""))
|
||||
return str(file_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path.joinpath(INVALID_YAML_FILE)
|
||||
file_path.write_text(dedent(
|
||||
"""\
|
||||
address:
|
||||
city: Example City
|
||||
country: Example Country
|
||||
age: 30
|
||||
gender: male
|
||||
languages:
|
||||
- Python
|
||||
- Java
|
||||
- C++
|
||||
"""))
|
||||
return str(file_path)
|
||||
|
||||
|
||||
def test_load_yaml_non_existing_file():
    """A missing file or an empty path yields an empty dict instead of raising."""
    for missing_path in (NON_EXISTING_YAML_FILE, ''):
        assert load_yaml_file(file_path=missing_path) == {}
|
||||
|
||||
|
||||
def test_load_valid_yaml_file(prepare_example_yaml_file):
    """Scalars, nested mappings, lists and empty keys are all read back correctly."""
    loaded = load_yaml_file(file_path=prepare_example_yaml_file)
    assert len(loaded) > 0

    # scalar values
    assert loaded['age'] == 30
    assert loaded['gender'] == 'male'

    # nested mapping and list
    address = loaded['address']
    assert address['city'] == 'Example City'
    languages = set(loaded['languages'])
    assert languages == {'Python', 'Java', 'C++'}

    # a key without a value and an absent key both read as None via .get()
    assert loaded.get('empty_key') is None
    assert loaded.get('non_existed_key') is None
|
||||
|
||||
|
||||
def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
    """Malformed YAML raises YAMLError by default and yields {} with ignore_error=True."""
    broken_path = prepare_invalid_yaml_file

    # yaml syntax error
    with pytest.raises(YAMLError):
        load_yaml_file(file_path=broken_path)

    # ignore error
    fallback = load_yaml_file(file_path=broken_path, ignore_error=True)
    assert fallback == {}
|
Loading…
x
Reference in New Issue
Block a user