diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index 04675d85bb..dd1534c791 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -13,18 +13,10 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> :param file_name: the YAML file name, default to '_position.yaml' :return: a dict with name as key and index as value """ - position_file_name = os.path.join(folder_path, file_name) - if not position_file_name or not os.path.exists(position_file_name): - return {} - - positions = load_yaml_file(position_file_name, ignore_error=True) - position_map = {} - index = 0 - for _, name in enumerate(positions): - if name and isinstance(name, str): - position_map[name.strip()] = index - index += 1 - return position_map + position_file_path = os.path.join(folder_path, file_name) + yaml_content = load_yaml_file(file_path=position_file_path, default_value=[]) + positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()] + return {name: index for index, name in enumerate(positions)} def sort_by_position_map( diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 04b539433c..0de216bf89 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -162,7 +162,7 @@ class AIModel(ABC): # traverse all model_schema_yaml_paths for model_schema_yaml_path in model_schema_yaml_paths: # read yaml data from yaml file - yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True) + yaml_data = load_yaml_file(model_schema_yaml_path) new_parameter_rules = [] for parameter_rule in yaml_data.get('parameter_rules', []): diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 51dd3b7e28..780460a3f7 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -44,7 +44,7 @@ class ModelProvider(ABC): # read provider schema from yaml file yaml_path = os.path.join(current_path, f'{provider_name}.yaml') - yaml_data = load_yaml_file(yaml_path, ignore_error=True) + yaml_data = load_yaml_file(yaml_path) try: # yaml_data to entity diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 47e33b70c9..bcf41c90ed 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -27,7 +27,7 @@ class BuiltinToolProviderController(ToolProviderController): provider = self.__class__.__module__.split('.')[-1] yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml') try: - provider_yaml = load_yaml_file(yaml_path) + provider_yaml = load_yaml_file(yaml_path, ignore_error=False) except Exception as e: raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}') @@ -58,7 +58,7 @@ class BuiltinToolProviderController(ToolProviderController): for tool_file in tool_files: # get tool name tool_name = tool_file.split(".")[0] - tool = load_yaml_file(path.join(tool_path, tool_file)) + tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False) # get tool class, import the module assistant_tool_class = load_single_subclass_from_source( diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index 3526647b4f..11486da7da 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -1,35 +1,31 @@ import logging -import os +from typing import Any import yaml from yaml import YAMLError logger = logging.getLogger(__name__) -def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict: + +def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any: """ - Safe loading a YAML file to a dict + Safe loading a YAML file :param file_path: the path of the YAML file :param ignore_error: - if True, return empty dict if error occurs and the error will be logged in warning level + if True, return default_value if error occurs and the error will be logged in warning level if False, raise error if error occurs - :return: a dict of the YAML content + :param default_value: the value returned when errors ignored + :return: an object of the YAML content """ try: - if not file_path or not os.path.exists(file_path): - raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found') - - with open(file_path, encoding='utf-8') as file: + with open(file_path, encoding='utf-8') as yaml_file: try: - return yaml.safe_load(file) + return yaml.safe_load(yaml_file) except Exception as e: raise YAMLError(f'Failed to load YAML file {file_path}: {e}') - except FileNotFoundError as e: - logger.debug(f'Failed to load YAML file {file_path}: {e}') - return {} except Exception as e: if ignore_error: logger.warning(f'Failed to load YAML file {file_path}: {e}') - return {} + return default_value else: raise e diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py index c389461454..2237319904 100644 --- a/api/tests/unit_tests/utils/position_helper/test_position_helper.py +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -21,6 +21,20 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: return str(tmp_path) +@pytest.fixture +def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str: + monkeypatch.chdir(tmp_path) + tmp_path.joinpath("example_positions_all_commented.yaml").write_text(dedent( + """\ + # - commented1 + # - commented2 + - + - + + """)) + return str(tmp_path) + + def test_position_helper(prepare_example_positions_yaml): position_map = get_position_map( folder_path=prepare_example_positions_yaml, @@ -32,3 +46,10 @@ def test_position_helper(prepare_example_positions_yaml): 'third': 2, 'forth': 3, } + + +def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml): + position_map = get_position_map( + folder_path=prepare_empty_commented_positions_yaml, + file_name='example_positions_all_commented.yaml') + assert position_map == {} diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index 446588cde1..c0452b4e4d 100644 --- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -53,6 +53,9 @@ def test_load_yaml_non_existing_file(): assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {} assert load_yaml_file(file_path='') == {} + with pytest.raises(FileNotFoundError): + load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False) + def test_load_valid_yaml_file(prepare_example_yaml_file): yaml_data = load_yaml_file(file_path=prepare_example_yaml_file) @@ -68,7 +71,7 @@ def test_load_valid_yaml_file(prepare_example_yaml_file): def test_load_invalid_yaml_file(prepare_invalid_yaml_file): # yaml syntax error with pytest.raises(YAMLError): - load_yaml_file(file_path=prepare_invalid_yaml_file) + load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False) # ignore error - assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {} + assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {}