mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 08:18:58 +08:00
generalize helper for loading module from source (#2862)
This commit is contained in:
parent
c8b82b9d08
commit
08b727833e
@ -1,5 +1,4 @@
|
||||
import enum
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@ -7,6 +6,7 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
from core.utils.position_helper import sort_to_dict_by_position_map
|
||||
|
||||
|
||||
@ -73,17 +73,9 @@ class Extensible:
|
||||
|
||||
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
|
||||
py_path = os.path.join(subdir_path, extension_name + '.py')
|
||||
spec = importlib.util.spec_from_file_location(extension_name, py_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
extension_class = None
|
||||
for name, obj in vars(mod).items():
|
||||
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
|
||||
extension_class = obj
|
||||
break
|
||||
|
||||
if not extension_class:
|
||||
try:
|
||||
extension_class = load_single_subclass_from_source(extension_name, py_path, cls)
|
||||
except Exception:
|
||||
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
|
||||
continue
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
import importlib
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
@ -7,6 +6,7 @@ import yaml
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
|
||||
|
||||
|
||||
class ModelProvider(ABC):
|
||||
@ -104,17 +104,10 @@ class ModelProvider(ABC):
|
||||
|
||||
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
|
||||
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
|
||||
spec = importlib.util.spec_from_file_location(f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
model_class = None
|
||||
for name, obj in vars(mod).items():
|
||||
if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
|
||||
and obj != AIModel and obj.__module__ == mod.__name__):
|
||||
model_class = obj
|
||||
break
|
||||
|
||||
mod = import_module_from_source(
|
||||
f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
|
||||
model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
|
||||
get_subclasses_from_module(mod, AIModel)), None)
|
||||
if not model_class:
|
||||
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
@ -10,6 +9,7 @@ from core.model_runtime.entities.provider_entities import ProviderConfig, Provid
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
|
||||
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -229,15 +229,10 @@ class ModelProviderFactory:
|
||||
|
||||
# Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
|
||||
py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
|
||||
spec = importlib.util.spec_from_file_location(f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}', py_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
model_provider_class = None
|
||||
for name, obj in vars(mod).items():
|
||||
if isinstance(obj, type) and issubclass(obj, ModelProvider) and obj != ModelProvider:
|
||||
model_provider_class = obj
|
||||
break
|
||||
model_provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
|
||||
script_path=py_path,
|
||||
parent_type=ModelProvider)
|
||||
|
||||
if not model_provider_class:
|
||||
logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
|
||||
|
@ -1,4 +1,3 @@
|
||||
import importlib
|
||||
from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
@ -16,6 +15,7 @@ from core.tools.errors import (
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
@ -63,16 +63,11 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
tool_name = tool_file.split(".")[0]
|
||||
tool = load(f.read(), FullLoader)
|
||||
# get tool class, import the module
|
||||
py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, 'tools', f'{tool_name}.py')
|
||||
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.tools.{tool_name}', py_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
# get all the classes in the module
|
||||
classes = [x for _, x in vars(mod).items()
|
||||
if isinstance(x, type) and x not in [BuiltinTool, Tool] and issubclass(x, BuiltinTool)
|
||||
]
|
||||
assistant_tool_class = classes[0]
|
||||
assistant_tool_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'builtin', provider, 'tools', f'{tool_name}.py'),
|
||||
parent_type=BuiltinTool)
|
||||
tools.append(assistant_tool_class(**tool))
|
||||
|
||||
self.tools = tools
|
||||
|
@ -1,4 +1,3 @@
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
@ -34,6 +33,7 @@ from core.tools.utils.configuration import (
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.encoder import serialize_base_model_dict
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider
|
||||
|
||||
@ -72,21 +72,11 @@ class ToolManager:
|
||||
|
||||
if provider_entity is None:
|
||||
# fetch the provider from .provider.builtin
|
||||
py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py')
|
||||
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
# get all the classes in the module
|
||||
classes = [ x for _, x in vars(mod).items()
|
||||
if isinstance(x, type) and x != ToolProviderController and issubclass(x, ToolProviderController)
|
||||
]
|
||||
if len(classes) == 0:
|
||||
raise ToolProviderNotFoundError(f'provider {provider} not found')
|
||||
if len(classes) > 1:
|
||||
raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
|
||||
|
||||
provider_entity = classes[0]()
|
||||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=ToolProviderController)
|
||||
provider_entity = provider_class()
|
||||
|
||||
return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
|
||||
|
||||
@ -330,23 +320,12 @@ class ToolManager:
|
||||
if provider.startswith('__'):
|
||||
continue
|
||||
|
||||
py_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, f'{provider}.py')
|
||||
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
# load all classes
|
||||
classes = [
|
||||
obj for name, obj in vars(mod).items()
|
||||
if isinstance(obj, type) and obj != BuiltinToolProviderController and issubclass(obj, BuiltinToolProviderController)
|
||||
]
|
||||
if len(classes) == 0:
|
||||
raise ToolProviderNotFoundError(f'provider {provider} not found')
|
||||
if len(classes) > 1:
|
||||
raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
|
||||
|
||||
# init provider
|
||||
provider_class = classes[0]
|
||||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=BuiltinToolProviderController)
|
||||
builtin_providers.append(provider_class())
|
||||
|
||||
# cache the builtin providers
|
||||
|
62
api/core/utils/module_import_helper.py
Normal file
62
api/core/utils/module_import_helper.py
Normal file
@ -0,0 +1,62 @@
|
||||
import importlib.util
|
||||
import logging
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import AnyStr
|
||||
|
||||
|
||||
def import_module_from_source(
|
||||
module_name: str,
|
||||
py_file_path: AnyStr,
|
||||
use_lazy_loader: bool = False
|
||||
) -> ModuleType:
|
||||
"""
|
||||
Importing a module from the source file directly
|
||||
"""
|
||||
try:
|
||||
existed_spec = importlib.util.find_spec(module_name)
|
||||
if existed_spec:
|
||||
spec = existed_spec
|
||||
else:
|
||||
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
|
||||
spec = importlib.util.spec_from_file_location(module_name, py_file_path)
|
||||
if use_lazy_loader:
|
||||
# Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
|
||||
spec.loader = importlib.util.LazyLoader(spec.loader)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
if not existed_spec:
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
except Exception as e:
|
||||
logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}')
|
||||
raise e
|
||||
|
||||
|
||||
def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]:
|
||||
"""
|
||||
Get all the subclasses of the parent type from the module
|
||||
"""
|
||||
classes = [x for _, x in vars(mod).items()
|
||||
if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)]
|
||||
return classes
|
||||
|
||||
|
||||
def load_single_subclass_from_source(
|
||||
module_name: str,
|
||||
script_path: AnyStr,
|
||||
parent_type: type,
|
||||
use_lazy_loader: bool = False,
|
||||
) -> type:
|
||||
"""
|
||||
Load a single subclass from the source
|
||||
"""
|
||||
module = import_module_from_source(module_name, script_path, use_lazy_loader)
|
||||
subclasses = get_subclasses_from_module(module, parent_type)
|
||||
match len(subclasses):
|
||||
case 1:
|
||||
return subclasses[0]
|
||||
case 0:
|
||||
raise Exception(f'Missing subclass of {parent_type.__name__} in {script_path}')
|
||||
case _:
|
||||
raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}')
|
7
api/tests/integration_tests/utils/child_class.py
Normal file
7
api/tests/integration_tests/utils/child_class.py
Normal file
@ -0,0 +1,7 @@
|
||||
from tests.integration_tests.utils.parent_class import ParentClass
|
||||
|
||||
|
||||
class ChildClass(ParentClass):
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self.name = name
|
7
api/tests/integration_tests/utils/lazy_load_class.py
Normal file
7
api/tests/integration_tests/utils/lazy_load_class.py
Normal file
@ -0,0 +1,7 @@
|
||||
from tests.integration_tests.utils.parent_class import ParentClass
|
||||
|
||||
|
||||
class LazyLoadChildClass(ParentClass):
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self.name = name
|
6
api/tests/integration_tests/utils/parent_class.py
Normal file
6
api/tests/integration_tests/utils/parent_class.py
Normal file
@ -0,0 +1,6 @@
|
||||
class ParentClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
@ -0,0 +1,32 @@
|
||||
import os
|
||||
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source, import_module_from_source
|
||||
from tests.integration_tests.utils.parent_class import ParentClass
|
||||
|
||||
|
||||
def test_loading_subclass_from_source():
|
||||
current_path = os.getcwd()
|
||||
module = load_single_subclass_from_source(
|
||||
module_name='ChildClass',
|
||||
script_path=os.path.join(current_path, 'child_class.py'),
|
||||
parent_type=ParentClass)
|
||||
assert module and module.__name__ == 'ChildClass'
|
||||
|
||||
|
||||
def test_load_import_module_from_source():
|
||||
current_path = os.getcwd()
|
||||
module = import_module_from_source(
|
||||
module_name='ChildClass',
|
||||
py_file_path=os.path.join(current_path, 'child_class.py'))
|
||||
assert module and module.__name__ == 'ChildClass'
|
||||
|
||||
|
||||
def test_lazy_loading_subclass_from_source():
|
||||
current_path = os.getcwd()
|
||||
clz = load_single_subclass_from_source(
|
||||
module_name='LazyLoadChildClass',
|
||||
script_path=os.path.join(current_path, 'lazy_load_class.py'),
|
||||
parent_type=ParentClass,
|
||||
use_lazy_loader=True)
|
||||
instance = clz('dify')
|
||||
assert instance.get_name() == 'dify'
|
Loading…
x
Reference in New Issue
Block a user