fix(core): Fix incorrect type hints. (#5427)

This commit is contained in:
-LAN- 2024-06-20 15:16:21 +08:00 committed by GitHub
parent e4259a8f13
commit 23fa3dedc4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 166 additions and 117 deletions

View File

@ -1,5 +1,5 @@
import enum import enum
import importlib import importlib.util
import json import json
import logging import logging
import os import os
@ -74,6 +74,8 @@ class Extensible:
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + '.py') py_path = os.path.join(subdir_path, extension_name + '.py')
spec = importlib.util.spec_from_file_location(extension_name, py_path) spec = importlib.util.spec_from_file_location(extension_name, py_path)
if not spec or not spec.loader:
raise Exception(f"Failed to load module {extension_name} from {py_path}")
mod = importlib.util.module_from_spec(spec) mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) spec.loader.exec_module(mod)
@ -108,6 +110,6 @@ class Extensible:
position=position position=position
)) ))
sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name) sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name)
return sorted_extensions return sorted_extensions

View File

@ -5,11 +5,7 @@ from types import ModuleType
from typing import AnyStr from typing import AnyStr
def import_module_from_source( def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
module_name: str,
py_file_path: AnyStr,
use_lazy_loader: bool = False
) -> ModuleType:
""" """
Importing a module from the source file directly Importing a module from the source file directly
""" """
@ -17,9 +13,13 @@ def import_module_from_source(
existed_spec = importlib.util.find_spec(module_name) existed_spec = importlib.util.find_spec(module_name)
if existed_spec: if existed_spec:
spec = existed_spec spec = existed_spec
if not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
else: else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
spec = importlib.util.spec_from_file_location(module_name, py_file_path) spec = importlib.util.spec_from_file_location(module_name, py_file_path)
if not spec or not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
if use_lazy_loader: if use_lazy_loader:
# Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
spec.loader = importlib.util.LazyLoader(spec.loader) spec.loader = importlib.util.LazyLoader(spec.loader)
@ -29,7 +29,7 @@ def import_module_from_source(
spec.loader.exec_module(module) spec.loader.exec_module(module)
return module return module
except Exception as e: except Exception as e:
logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}') logging.exception(f"Failed to load module {module_name} from {py_file_path}: {str(e)}")
raise e raise e
@ -43,15 +43,14 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
def load_single_subclass_from_source( def load_single_subclass_from_source(
module_name: str, *, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
script_path: AnyStr,
parent_type: type,
use_lazy_loader: bool = False,
) -> type: ) -> type:
""" """
Load a single subclass from the source Load a single subclass from the source
""" """
module = import_module_from_source(module_name, script_path, use_lazy_loader) module = import_module_from_source(
module_name=module_name, py_file_path=script_path, use_lazy_loader=use_lazy_loader
)
subclasses = get_subclasses_from_module(module, parent_type) subclasses = get_subclasses_from_module(module, parent_type)
match len(subclasses): match len(subclasses):
case 1: case 1:

View File

@ -1,15 +1,12 @@
import os import os
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable from collections.abc import Callable
from typing import Any, AnyStr from typing import Any
from core.tools.utils.yaml_utils import load_yaml_file from core.tools.utils.yaml_utils import load_yaml_file
def get_position_map( def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]:
folder_path: AnyStr,
file_name: str = '_position.yaml',
) -> dict[str, int]:
""" """
Get the mapping from name to index from a YAML file Get the mapping from name to index from a YAML file
:param folder_path: :param folder_path:

View File

@ -1,6 +1,6 @@
import logging import logging
import os import os
from collections.abc import Generator from collections.abc import Callable, Generator
from typing import IO, Optional, Union, cast from typing import IO, Optional, Union, cast
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@ -102,7 +102,7 @@ class ModelInstance:
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-> Union[LLMResult, Generator]: -> Union[LLMResult, Generator]:
""" """
Invoke large language model Invoke large language model
@ -291,7 +291,7 @@ class ModelInstance:
streaming=streaming streaming=streaming
) )
def _round_robin_invoke(self, function: callable, *args, **kwargs): def _round_robin_invoke(self, function: Callable, *args, **kwargs):
""" """
Round-robin invoke Round-robin invoke
:param function: function to invoke :param function: function to invoke
@ -437,6 +437,7 @@ class LBModelManager:
while True: while True:
current_index = redis_client.incr(cache_key) current_index = redis_client.incr(cache_key)
current_index = cast(int, current_index)
if current_index >= 10000000: if current_index >= 10000000:
current_index = 1 current_index = 1
redis_client.set(cache_key, current_index) redis_client.set(cache_key, current_index)
@ -499,7 +500,10 @@ class LBModelManager:
config.id config.id
) )
return redis_client.exists(cooldown_cache_key)
res = redis_client.exists(cooldown_cache_key)
res = cast(bool, res)
return res
@classmethod @classmethod
def get_config_in_cooldown_and_ttl(cls, tenant_id: str, def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
@ -528,4 +532,5 @@ class LBModelManager:
if ttl == -2: if ttl == -2:
return False, 0 return False, 0
ttl = cast(int, ttl)
return True, ttl return True, ttl

View File

@ -1,10 +1,11 @@
from collections.abc import Sequence
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel from core.model_runtime.entities.model_entities import ModelType, ProviderModel
class ConfigurateMethod(Enum): class ConfigurateMethod(Enum):
@ -93,8 +94,8 @@ class SimpleProviderEntity(BaseModel):
label: I18nObject label: I18nObject
icon_small: Optional[I18nObject] = None icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType] supported_model_types: Sequence[ModelType]
models: list[AIModelEntity] = [] models: list[ProviderModel] = []
class ProviderHelpEntity(BaseModel): class ProviderHelpEntity(BaseModel):
@ -116,7 +117,7 @@ class ProviderEntity(BaseModel):
icon_large: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None
background: Optional[str] = None background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None help: Optional[ProviderHelpEntity] = None
supported_model_types: list[ModelType] supported_model_types: Sequence[ModelType]
configurate_methods: list[ConfigurateMethod] configurate_methods: list[ConfigurateMethod]
models: list[ProviderModel] = [] models: list[ProviderModel] = []
provider_credential_schema: Optional[ProviderCredentialSchema] = None provider_credential_schema: Optional[ProviderCredentialSchema] = None

View File

@ -1,6 +1,7 @@
import decimal import decimal
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Optional from typing import Optional
from pydantic import ConfigDict from pydantic import ConfigDict
@ -26,15 +27,16 @@ class AIModel(ABC):
""" """
Base class for all models. Base class for all models.
""" """
model_type: ModelType model_type: ModelType
model_schemas: list[AIModelEntity] = None model_schemas: Optional[list[AIModelEntity]] = None
started_at: float = 0 started_at: float = 0
# pydantic configs # pydantic configs
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@abstractmethod @abstractmethod
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: Mapping) -> None:
""" """
Validate model credentials Validate model credentials
@ -90,8 +92,8 @@ class AIModel(ABC):
# get price info from predefined model schema # get price info from predefined model schema
price_config: Optional[PriceConfig] = None price_config: Optional[PriceConfig] = None
if model_schema: if model_schema and model_schema.pricing:
price_config: PriceConfig = model_schema.pricing price_config = model_schema.pricing
# get unit price # get unit price
unit_price = None unit_price = None
@ -103,13 +105,15 @@ class AIModel(ABC):
if unit_price is None: if unit_price is None:
return PriceInfo( return PriceInfo(
unit_price=decimal.Decimal('0.0'), unit_price=decimal.Decimal("0.0"),
unit=decimal.Decimal('0.0'), unit=decimal.Decimal("0.0"),
total_amount=decimal.Decimal('0.0'), total_amount=decimal.Decimal("0.0"),
currency="USD", currency="USD",
) )
# calculate total amount # calculate total amount
if not price_config:
raise ValueError(f"Price config not found for model {model}")
total_amount = tokens * unit_price * price_config.unit total_amount = tokens * unit_price * price_config.unit
total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
@ -209,7 +213,7 @@ class AIModel(ABC):
return model_schemas return model_schemas
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]: def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
""" """
Get model schema by model name and credentials Get model schema by model name and credentials
@ -231,7 +235,7 @@ class AIModel(ABC):
return None return None
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
""" """
Get customizable model schema from credentials Get customizable model schema from credentials
@ -241,7 +245,7 @@ class AIModel(ABC):
""" """
return self._get_customizable_model_schema(model, credentials) return self._get_customizable_model_schema(model, credentials)
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
""" """
Get customizable model schema and fill in the template Get customizable model schema and fill in the template
""" """
@ -271,10 +275,20 @@ class AIModel(ABC):
parameter_rule.help = I18nObject( parameter_rule.help = I18nObject(
en_US=default_parameter_rule['help']['en_US'], en_US=default_parameter_rule['help']['en_US'],
) )
if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']): if (
parameter_rule.help.en_US = default_parameter_rule['help']['en_US'] parameter_rule.help
if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']): and not parameter_rule.help.en_US
parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US']) and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
):
parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
if (
parameter_rule.help
and not parameter_rule.help.zh_Hans
and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
):
parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
"zh_Hans", default_parameter_rule["help"]["en_US"]
)
except ValueError: except ValueError:
pass pass
@ -284,7 +298,7 @@ class AIModel(ABC):
return schema return schema
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
""" """
Get customizable model schema Get customizable model schema
@ -304,7 +318,7 @@ class AIModel(ABC):
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name) default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
if not default_parameter_rule: if not default_parameter_rule:
raise Exception(f'Invalid model parameter rule name {name}') raise Exception(f"Invalid model parameter rule name {name}")
return default_parameter_rule return default_parameter_rule

View File

@ -3,7 +3,7 @@ import os
import re import re
import time import time
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator from collections.abc import Generator, Mapping
from typing import Optional, Union from typing import Optional, Union
from pydantic import ConfigDict from pydantic import ConfigDict
@ -43,7 +43,7 @@ class LargeLanguageModel(AIModel):
def invoke(self, model: str, credentials: dict, def invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-> Union[LLMResult, Generator]: -> Union[LLMResult, Generator]:
""" """
Invoke large language model Invoke large language model
@ -129,7 +129,7 @@ class LargeLanguageModel(AIModel):
user=user, user=user,
callbacks=callbacks callbacks=callbacks
) )
else: elif isinstance(result, LLMResult):
self._trigger_after_invoke_callbacks( self._trigger_after_invoke_callbacks(
model=model, model=model,
result=result, result=result,
@ -148,7 +148,7 @@ class LargeLanguageModel(AIModel):
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]:
""" """
Code block mode wrapper, ensure the response is a code block with output markdown quote Code block mode wrapper, ensure the response is a code block with output markdown quote
@ -196,7 +196,7 @@ if you are not sure about the structure.
# override the system message # override the system message
prompt_messages[0] = SystemPromptMessage( prompt_messages[0] = SystemPromptMessage(
content=block_prompts content=block_prompts
.replace("{{instructions}}", prompt_messages[0].content) .replace("{{instructions}}", str(prompt_messages[0].content))
) )
else: else:
# insert the system message # insert the system message
@ -274,8 +274,9 @@ if you are not sure about the structure.
else: else:
yield piece yield piece
continue continue
new_piece = "" new_piece: str = ""
for char in piece: for char in piece:
char = str(char)
if state == "normal": if state == "normal":
if char == "`": if char == "`":
state = "in_backticks" state = "in_backticks"
@ -340,7 +341,7 @@ if you are not sure about the structure.
if state == "done": if state == "done":
continue continue
new_piece = "" new_piece: str = ""
for char in piece: for char in piece:
if state == "search_start": if state == "search_start":
if char == "`": if char == "`":
@ -365,7 +366,7 @@ if you are not sure about the structure.
# If backticks were counted but we're still collecting content, it was a false start # If backticks were counted but we're still collecting content, it was a false start
new_piece += "`" * backtick_count new_piece += "`" * backtick_count
backtick_count = 0 backtick_count = 0
new_piece += char new_piece += str(char)
elif state == "done": elif state == "done":
break break
@ -388,13 +389,14 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], model_parameters: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator: user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator:
""" """
Invoke result generator Invoke result generator
:param result: result generator :param result: result generator
:return: result generator :return: result generator
""" """
callbacks = callbacks or []
prompt_message = AssistantPromptMessage( prompt_message = AssistantPromptMessage(
content="" content=""
) )
@ -489,6 +491,7 @@ if you are not sure about the structure.
def _llm_result_to_stream(self, result: LLMResult) -> Generator: def _llm_result_to_stream(self, result: LLMResult) -> Generator:
""" """
from typing_extensions import deprecated
Transform llm result to stream Transform llm result to stream
:param result: llm result :param result: llm result
@ -531,7 +534,7 @@ if you are not sure about the structure.
return [] return []
def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode: def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
""" """
Get model mode Get model mode
@ -595,7 +598,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], model_parameters: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None: user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
""" """
Trigger before invoke callbacks Trigger before invoke callbacks
@ -633,7 +636,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], model_parameters: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None: user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
""" """
Trigger new chunk callbacks Trigger new chunk callbacks
@ -672,7 +675,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], model_parameters: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None: user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
""" """
Trigger after invoke callbacks Trigger after invoke callbacks
@ -712,7 +715,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], model_parameters: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None: user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
""" """
Trigger invoke error callbacks Trigger invoke error callbacks

View File

@ -1,5 +1,6 @@
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional
from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@ -9,7 +10,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
class ModelProvider(ABC): class ModelProvider(ABC):
provider_schema: ProviderEntity = None provider_schema: Optional[ProviderEntity] = None
model_instance_map: dict[str, AIModel] = {} model_instance_map: dict[str, AIModel] = {}
@abstractmethod @abstractmethod
@ -84,7 +85,7 @@ class ModelProvider(ABC):
:return: :return:
""" """
# get dirname of the current path # get dirname of the current path
provider_name = self.__class__.__module__.split('.')[-1] provider_name = self.__class__.__module__.split(".")[-1]
if f"{provider_name}.{model_type.value}" in self.model_instance_map: if f"{provider_name}.{model_type.value}" in self.model_instance_map:
return self.model_instance_map[f"{provider_name}.{model_type.value}"] return self.model_instance_map[f"{provider_name}.{model_type.value}"]
@ -101,11 +102,17 @@ class ModelProvider(ABC):
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel # Dynamic loading {model_type_name}.py file and find the subclass of AIModel
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1]) parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
mod = import_module_from_source( mod = import_module_from_source(
f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path) module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, )
get_subclasses_from_module(mod, AIModel)), None) model_class = next(
filter(
lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
get_subclasses_from_module(mod, AIModel),
),
None,
)
if not model_class: if not model_class:
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}') raise Exception(f"Missing AIModel Class for model type {model_type} in {model_type_py_path}")
model_instance_map = model_class() model_instance_map = model_class()
self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map

View File

@ -1,5 +1,6 @@
import logging import logging
import os import os
from collections.abc import Sequence
from typing import Optional from typing import Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -16,20 +17,21 @@ logger = logging.getLogger(__name__)
class ModelProviderExtension(BaseModel): class ModelProviderExtension(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
provider_instance: ModelProvider provider_instance: ModelProvider
name: str name: str
position: Optional[int] = None position: Optional[int] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
class ModelProviderFactory: class ModelProviderFactory:
model_provider_extensions: dict[str, ModelProviderExtension] = None model_provider_extensions: Optional[dict[str, ModelProviderExtension]] = None
def __init__(self) -> None: def __init__(self) -> None:
# for cache in memory # for cache in memory
self.get_providers() self.get_providers()
def get_providers(self) -> list[ProviderEntity]: def get_providers(self) -> Sequence[ProviderEntity]:
""" """
Get all providers Get all providers
:return: list of providers :return: list of providers
@ -39,7 +41,7 @@ class ModelProviderFactory:
# traverse all model_provider_extensions # traverse all model_provider_extensions
providers = [] providers = []
for name, model_provider_extension in model_provider_extensions.items(): for model_provider_extension in model_provider_extensions.values():
# get model_provider instance # get model_provider instance
model_provider_instance = model_provider_extension.provider_instance model_provider_instance = model_provider_extension.provider_instance
@ -57,7 +59,7 @@ class ModelProviderFactory:
# return providers # return providers
return providers return providers
def provider_credentials_validate(self, provider: str, credentials: dict) -> dict: def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict:
""" """
Validate provider credentials Validate provider credentials
@ -74,6 +76,9 @@ class ModelProviderFactory:
# get provider_credential_schema and validate credentials according to the rules # get provider_credential_schema and validate credentials according to the rules
provider_credential_schema = provider_schema.provider_credential_schema provider_credential_schema = provider_schema.provider_credential_schema
if not provider_credential_schema:
raise ValueError(f"Provider {provider} does not have provider_credential_schema")
# validate provider credential schema # validate provider credential schema
validator = ProviderCredentialSchemaValidator(provider_credential_schema) validator = ProviderCredentialSchemaValidator(provider_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials) filtered_credentials = validator.validate_and_filter(credentials)
@ -83,8 +88,9 @@ class ModelProviderFactory:
return filtered_credentials return filtered_credentials
def model_credentials_validate(self, provider: str, model_type: ModelType, def model_credentials_validate(
model: str, credentials: dict) -> dict: self, *, provider: str, model_type: ModelType, model: str, credentials: dict
) -> dict:
""" """
Validate model credentials Validate model credentials
@ -103,6 +109,9 @@ class ModelProviderFactory:
# get model_credential_schema and validate credentials according to the rules # get model_credential_schema and validate credentials according to the rules
model_credential_schema = provider_schema.model_credential_schema model_credential_schema = provider_schema.model_credential_schema
if not model_credential_schema:
raise ValueError(f"Provider {provider} does not have model_credential_schema")
# validate model credential schema # validate model credential schema
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials) filtered_credentials = validator.validate_and_filter(credentials)
@ -115,11 +124,13 @@ class ModelProviderFactory:
return filtered_credentials return filtered_credentials
def get_models(self, def get_models(
self,
*,
provider: Optional[str] = None, provider: Optional[str] = None,
model_type: Optional[ModelType] = None, model_type: Optional[ModelType] = None,
provider_configs: Optional[list[ProviderConfig]] = None) \ provider_configs: Optional[list[ProviderConfig]] = None,
-> list[SimpleProviderEntity]: ) -> list[SimpleProviderEntity]:
""" """
Get all models for given model type Get all models for given model type
@ -128,6 +139,8 @@ class ModelProviderFactory:
:param provider_configs: list of provider configs :param provider_configs: list of provider configs
:return: list of models :return: list of models
""" """
provider_configs = provider_configs or []
# scan all providers # scan all providers
model_provider_extensions = self._get_model_provider_map() model_provider_extensions = self._get_model_provider_map()
@ -184,7 +197,7 @@ class ModelProviderFactory:
# get the provider extension # get the provider extension
model_provider_extension = model_provider_extensions.get(provider) model_provider_extension = model_provider_extensions.get(provider)
if not model_provider_extension: if not model_provider_extension:
raise Exception(f'Invalid provider: {provider}') raise Exception(f"Invalid provider: {provider}")
# get the provider instance # get the provider instance
model_provider_instance = model_provider_extension.provider_instance model_provider_instance = model_provider_extension.provider_instance
@ -192,10 +205,22 @@ class ModelProviderFactory:
return model_provider_instance return model_provider_instance
def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]: def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
"""
Retrieves the model provider map.
This method retrieves the model provider map, which is a dictionary containing the model provider names as keys
and instances of `ModelProviderExtension` as values. The model provider map is used to store information about
available model providers.
Returns:
A dictionary containing the model provider map.
Raises:
None.
"""
if self.model_provider_extensions: if self.model_provider_extensions:
return self.model_provider_extensions return self.model_provider_extensions
# get the path of current classes # get the path of current classes
current_path = os.path.abspath(__file__) current_path = os.path.abspath(__file__)
model_providers_path = os.path.dirname(current_path) model_providers_path = os.path.dirname(current_path)
@ -204,7 +229,7 @@ class ModelProviderFactory:
model_provider_dir_paths = [ model_provider_dir_paths = [
os.path.join(model_providers_path, model_provider_dir) os.path.join(model_providers_path, model_provider_dir)
for model_provider_dir in os.listdir(model_providers_path) for model_provider_dir in os.listdir(model_providers_path)
if not model_provider_dir.startswith('__') if not model_provider_dir.startswith("__")
and os.path.isdir(os.path.join(model_providers_path, model_provider_dir)) and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
] ]
@ -219,30 +244,33 @@ class ModelProviderFactory:
file_names = os.listdir(model_provider_dir_path) file_names = os.listdir(model_provider_dir_path)
if (model_provider_name + '.py') not in file_names: if (model_provider_name + ".py") not in file_names:
logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.") logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.")
continue continue
# Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider # Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py') py_path = os.path.join(model_provider_dir_path, model_provider_name + ".py")
model_provider_class = load_single_subclass_from_source( model_provider_class = load_single_subclass_from_source(
module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}', module_name=f"core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}",
script_path=py_path, script_path=py_path,
parent_type=ModelProvider) parent_type=ModelProvider,
)
if not model_provider_class: if not model_provider_class:
logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.") logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
continue continue
if f'{model_provider_name}.yaml' not in file_names: if f"{model_provider_name}.yaml" not in file_names:
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.") logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
continue continue
model_providers.append(ModelProviderExtension( model_providers.append(
ModelProviderExtension(
name=model_provider_name, name=model_provider_name,
provider_instance=model_provider_class(), provider_instance=model_provider_class(),
position=position_map.get(model_provider_name) position=position_map.get(model_provider_name),
)) )
)
sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name) sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)

View File

@ -1,3 +1,5 @@
from collections.abc import Mapping
import openai import openai
from httpx import Timeout from httpx import Timeout
@ -12,7 +14,7 @@ from core.model_runtime.errors.invoke import (
class _CommonOpenAI: class _CommonOpenAI:
def _to_credential_kwargs(self, credentials: dict) -> dict: def _to_credential_kwargs(self, credentials: Mapping) -> dict:
""" """
Transform credentials to kwargs for model instance Transform credentials to kwargs for model instance
@ -25,9 +27,9 @@ class _CommonOpenAI:
"max_retries": 1, "max_retries": 1,
} }
if credentials.get('openai_api_base'): if credentials.get("openai_api_base"):
credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/') openai_api_base = credentials["openai_api_base"].rstrip("/")
credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1' credentials_kwargs["base_url"] = openai_api_base + "/v1"
if 'openai_organization' in credentials: if 'openai_organization' in credentials:
credentials_kwargs['organization'] = credentials['openai_organization'] credentials_kwargs['organization'] = credentials['openai_organization']
@ -45,24 +47,14 @@ class _CommonOpenAI:
:return: Invoke error mapping :return: Invoke error mapping
""" """
return { return {
InvokeConnectionError: [ InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
openai.APIConnectionError, InvokeServerUnavailableError: [openai.InternalServerError],
openai.APITimeoutError InvokeRateLimitError: [openai.RateLimitError],
], InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
InvokeServerUnavailableError: [
openai.InternalServerError
],
InvokeRateLimitError: [
openai.RateLimitError
],
InvokeAuthorizationError: [
openai.AuthenticationError,
openai.PermissionDeniedError
],
InvokeBadRequestError: [ InvokeBadRequestError: [
openai.BadRequestError, openai.BadRequestError,
openai.NotFoundError, openai.NotFoundError,
openai.UnprocessableEntityError, openai.UnprocessableEntityError,
openai.APIError openai.APIError,
] ],
} }

View File

@ -1,4 +1,5 @@
import logging import logging
from collections.abc import Mapping
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -9,7 +10,7 @@ logger = logging.getLogger(__name__)
class OpenAIProvider(ModelProvider): class OpenAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None: def validate_provider_credentials(self, credentials: Mapping) -> None:
""" """
Validate provider credentials Validate provider credentials
if validate failed, raise exception if validate failed, raise exception