diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 1809dcd8df..c9c57b60b3 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -1,5 +1,4 @@ import enum -import importlib.util import json import logging import os @@ -7,6 +6,7 @@ from typing import Any, Optional from pydantic import BaseModel +from core.utils.module_import_helper import load_single_subclass_from_source from core.utils.position_helper import sort_to_dict_by_position_map @@ -73,17 +73,9 @@ class Extensible: # Dynamic loading {subdir_name}.py file and find the subclass of Extensible py_path = os.path.join(subdir_path, extension_name + '.py') - spec = importlib.util.spec_from_file_location(extension_name, py_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - - extension_class = None - for name, obj in vars(mod).items(): - if isinstance(obj, type) and issubclass(obj, cls) and obj != cls: - extension_class = obj - break - - if not extension_class: + try: + extension_class = load_single_subclass_from_source(extension_name, py_path, cls) + except Exception: logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") continue diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 97ce07d35f..7c839a9672 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -1,4 +1,3 @@ -import importlib import os from abc import ABC, abstractmethod @@ -7,6 +6,7 @@ import yaml from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderEntity from core.model_runtime.model_providers.__base.ai_model import AIModel +from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source class ModelProvider(ABC): @@ -104,17 +104,10 @@ class ModelProvider(ABC): # Dynamic loading {model_type_name}.py file and find the subclass of AIModel parent_module = '.'.join(self.__class__.__module__.split('.')[:-1]) - spec = importlib.util.spec_from_file_location(f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - - model_class = None - for name, obj in vars(mod).items(): - if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__ - and obj != AIModel and obj.__module__ == mod.__name__): - model_class = obj - break - + mod = import_module_from_source( + f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path) + model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, + get_subclasses_from_module(mod, AIModel)), None) if not model_class: raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}') diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index ee0385c6d0..44a1cf2e84 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -1,4 +1,3 @@ -import importlib import logging import os from typing import Optional @@ -10,6 +9,7 @@ from core.model_runtime.entities.provider_entities import ProviderConfig, Provid from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator +from core.utils.module_import_helper import load_single_subclass_from_source from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map logger = logging.getLogger(__name__) @@ -229,15 +229,10 @@ class ModelProviderFactory: # Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py') - spec = importlib.util.spec_from_file_location(f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}', py_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - - model_provider_class = None - for name, obj in vars(mod).items(): - if isinstance(obj, type) and issubclass(obj, ModelProvider) and obj != ModelProvider: - model_provider_class = obj - break + model_provider_class = load_single_subclass_from_source( + module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}', + script_path=py_path, + parent_type=ModelProvider) if not model_provider_class: logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.") diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 824f91c822..62e664a8f8 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -1,4 +1,3 @@ -import importlib from abc import abstractmethod from os import listdir, path from typing import Any @@ -16,6 +15,7 @@ from core.tools.errors import ( from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool +from core.utils.module_import_helper import load_single_subclass_from_source class BuiltinToolProviderController(ToolProviderController): @@ -63,16 +63,11 @@ class BuiltinToolProviderController(ToolProviderController): tool_name = tool_file.split(".")[0] tool = load(f.read(), FullLoader) # get tool class, import the module - py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, 'tools', f'{tool_name}.py') - spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.tools.{tool_name}', py_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - - # get all the classes in the module - classes = [x for _, x in vars(mod).items() - if isinstance(x, type) and x not in [BuiltinTool, Tool] and issubclass(x, BuiltinTool) - ] - assistant_tool_class = classes[0] + assistant_tool_class = load_single_subclass_from_source( + module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}', + script_path=path.join(path.dirname(path.realpath(__file__)), + 'builtin', provider, 'tools', f'{tool_name}.py'), + parent_type=BuiltinTool) tools.append(assistant_tool_class(**tool)) self.tools = tools diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 2ac8f27bab..a09821279a 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -1,4 +1,3 @@ -import importlib import json import logging import mimetypes @@ -34,6 +33,7 @@ from core.tools.utils.configuration import ( ToolParameterConfigurationManager, ) from core.tools.utils.encoder import serialize_base_model_dict +from core.utils.module_import_helper import load_single_subclass_from_source from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider @@ -72,21 +72,11 @@ class ToolManager: if provider_entity is None: # fetch the provider from .provider.builtin - py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py') - spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - - # get all the classes in the module - classes = [ x for _, x in vars(mod).items() - if isinstance(x, type) and x != ToolProviderController and issubclass(x, ToolProviderController) - ] - if len(classes) == 0: - raise ToolProviderNotFoundError(f'provider {provider} not found') - if len(classes) > 1: - raise ToolProviderNotFoundError(f'multiple providers found for {provider}') - - provider_entity = classes[0]() + provider_class = load_single_subclass_from_source( + module_name=f'core.tools.provider.builtin.{provider}.{provider}', + script_path=path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py'), + parent_type=ToolProviderController) + provider_entity = provider_class() return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages) @@ -330,23 +320,12 @@ class ToolManager: if provider.startswith('__'): continue - py_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, f'{provider}.py') - spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - - # load all classes - classes = [ - obj for name, obj in vars(mod).items() - if isinstance(obj, type) and obj != BuiltinToolProviderController and issubclass(obj, BuiltinToolProviderController) - ] - if len(classes) == 0: - raise ToolProviderNotFoundError(f'provider {provider} not found') - if len(classes) > 1: - raise ToolProviderNotFoundError(f'multiple providers found for {provider}') - # init provider - provider_class = classes[0] + provider_class = load_single_subclass_from_source( + module_name=f'core.tools.provider.builtin.{provider}.{provider}', + script_path=path.join(path.dirname(path.realpath(__file__)), + 'provider', 'builtin', provider, f'{provider}.py'), + parent_type=BuiltinToolProviderController) builtin_providers.append(provider_class()) # cache the builtin providers diff --git a/api/core/utils/module_import_helper.py b/api/core/utils/module_import_helper.py new file mode 100644 index 0000000000..9e6e02f29f --- /dev/null +++ b/api/core/utils/module_import_helper.py @@ -0,0 +1,62 @@ +import importlib.util +import logging +import sys +from types import ModuleType +from typing import AnyStr + + +def import_module_from_source( + module_name: str, + py_file_path: AnyStr, + use_lazy_loader: bool = False +) -> ModuleType: + """ + Importing a module from the source file directly + """ + try: + existed_spec = importlib.util.find_spec(module_name) + if existed_spec: + spec = existed_spec + else: + # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly + spec = importlib.util.spec_from_file_location(module_name, py_file_path) + if use_lazy_loader: + # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports + spec.loader = importlib.util.LazyLoader(spec.loader) + module = importlib.util.module_from_spec(spec) + if not existed_spec: + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + except Exception as e: + logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}') + raise e + + +def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]: + """ + Get all the subclasses of the parent type from the module + """ + classes = [x for _, x in vars(mod).items() + if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)] + return classes + + +def load_single_subclass_from_source( + module_name: str, + script_path: AnyStr, + parent_type: type, + use_lazy_loader: bool = False, +) -> type: + """ + Load a single subclass from the source + """ + module = import_module_from_source(module_name, script_path, use_lazy_loader) + subclasses = get_subclasses_from_module(module, parent_type) + match len(subclasses): + case 1: + return subclasses[0] + case 0: + raise Exception(f'Missing subclass of {parent_type.__name__} in {script_path}') + case _: + raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}') diff --git a/api/tests/integration_tests/utils/child_class.py b/api/tests/integration_tests/utils/child_class.py new file mode 100644 index 0000000000..f9e5f341ff --- /dev/null +++ b/api/tests/integration_tests/utils/child_class.py @@ -0,0 +1,7 @@ +from tests.integration_tests.utils.parent_class import ParentClass + + +class ChildClass(ParentClass): + def __init__(self, name: str): + super().__init__(name) + self.name = name diff --git a/api/tests/integration_tests/utils/lazy_load_class.py b/api/tests/integration_tests/utils/lazy_load_class.py new file mode 100644 index 0000000000..ec881a470a --- /dev/null +++ b/api/tests/integration_tests/utils/lazy_load_class.py @@ -0,0 +1,7 @@ +from tests.integration_tests.utils.parent_class import ParentClass + + +class LazyLoadChildClass(ParentClass): + def __init__(self, name: str): + super().__init__(name) + self.name = name diff --git a/api/tests/integration_tests/utils/parent_class.py b/api/tests/integration_tests/utils/parent_class.py new file mode 100644 index 0000000000..39fc95256e --- /dev/null +++ b/api/tests/integration_tests/utils/parent_class.py @@ -0,0 +1,6 @@ +class ParentClass: + def __init__(self, name): + self.name = name + + def get_name(self): + return self.name \ No newline at end of file diff --git a/api/tests/integration_tests/utils/test_module_import_helper.py b/api/tests/integration_tests/utils/test_module_import_helper.py new file mode 100644 index 0000000000..e7da226434 --- /dev/null +++ b/api/tests/integration_tests/utils/test_module_import_helper.py @@ -0,0 +1,32 @@ +import os + +from core.utils.module_import_helper import load_single_subclass_from_source, import_module_from_source +from tests.integration_tests.utils.parent_class import ParentClass + + +def test_loading_subclass_from_source(): + current_path = os.getcwd() + module = load_single_subclass_from_source( + module_name='ChildClass', + script_path=os.path.join(current_path, 'child_class.py'), + parent_type=ParentClass) + assert module and module.__name__ == 'ChildClass' + + +def test_load_import_module_from_source(): + current_path = os.getcwd() + module = import_module_from_source( + module_name='ChildClass', + py_file_path=os.path.join(current_path, 'child_class.py')) + assert module and module.__name__ == 'ChildClass' + + +def test_lazy_loading_subclass_from_source(): + current_path = os.getcwd() + clz = load_single_subclass_from_source( + module_name='LazyLoadChildClass', + script_path=os.path.join(current_path, 'lazy_load_class.py'), + parent_type=ParentClass, + use_lazy_loader=True) + instance = clz('dify') + assert instance.get_name() == 'dify'