From 23fa3dedc487c6f2c64000f6108f2c691ef30f12 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 20 Jun 2024 15:16:21 +0800 Subject: [PATCH] fix(core): Fix incorrect type hints. (#5427) --- api/core/extension/extensible.py | 6 +- api/core/helper/module_import_helper.py | 21 +++-- api/core/helper/position_helper.py | 7 +- api/core/model_manager.py | 13 ++- .../entities/provider_entities.py | 9 ++- .../model_providers/__base/ai_model.py | 52 +++++++----- .../__base/large_language_model.py | 31 +++---- .../model_providers/__base/model_provider.py | 29 ++++--- .../model_providers/model_provider_factory.py | 80 +++++++++++++------ .../model_providers/openai/_common.py | 32 +++----- .../model_providers/openai/openai.py | 3 +- 11 files changed, 166 insertions(+), 117 deletions(-) diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index b106542564..0296126d8b 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -1,5 +1,5 @@ import enum -import importlib +import importlib.util import json import logging import os @@ -74,6 +74,8 @@ class Extensible: # Dynamic loading {subdir_name}.py file and find the subclass of Extensible py_path = os.path.join(subdir_path, extension_name + '.py') spec = importlib.util.spec_from_file_location(extension_name, py_path) + if not spec or not spec.loader: + raise Exception(f"Failed to load module {extension_name} from {py_path}") mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) @@ -108,6 +110,6 @@ class Extensible: position=position )) - sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name) + sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name) return sorted_extensions diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index d3a4bab4a1..2000577a40 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -5,11 +5,7 @@ from types import ModuleType from typing import AnyStr -def import_module_from_source( - module_name: str, - py_file_path: AnyStr, - use_lazy_loader: bool = False -) -> ModuleType: +def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType: """ Importing a module from the source file directly """ @@ -17,9 +13,13 @@ def import_module_from_source( existed_spec = importlib.util.find_spec(module_name) if existed_spec: spec = existed_spec + if not spec.loader: + raise Exception(f"Failed to load module {module_name} from {py_file_path}") else: # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly spec = importlib.util.spec_from_file_location(module_name, py_file_path) + if not spec or not spec.loader: + raise Exception(f"Failed to load module {module_name} from {py_file_path}") if use_lazy_loader: # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports spec.loader = importlib.util.LazyLoader(spec.loader) @@ -29,7 +29,7 @@ def import_module_from_source( spec.loader.exec_module(module) return module except Exception as e: - logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}') + logging.exception(f"Failed to load module {module_name} from {py_file_path}: {str(e)}") raise e @@ -43,15 +43,14 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type] def load_single_subclass_from_source( - module_name: str, - script_path: AnyStr, - parent_type: type, - use_lazy_loader: bool = False, + *, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False ) -> type: """ Load a single subclass from the source """ - module = import_module_from_source(module_name, script_path, use_lazy_loader) + module = import_module_from_source( + module_name=module_name, py_file_path=script_path, use_lazy_loader=use_lazy_loader + ) subclasses = get_subclasses_from_module(module, parent_type) match len(subclasses): case 1: diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index 689ab194a7..e4ceeb652e 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -1,15 +1,12 @@ import os from collections import OrderedDict from collections.abc import Callable -from typing import Any, AnyStr +from typing import Any from core.tools.utils.yaml_utils import load_yaml_file -def get_position_map( - folder_path: AnyStr, - file_name: str = '_position.yaml', -) -> dict[str, int]: +def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]: """ Get the mapping from name to index from a YAML file :param folder_path: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 51dff09609..e0b6960c23 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,6 +1,6 @@ import logging import os -from collections.abc import Generator +from collections.abc import Callable, Generator from typing import IO, Optional, Union, cast from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle @@ -102,7 +102,7 @@ class ModelInstance: def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ + stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ -> Union[LLMResult, Generator]: """ Invoke large language model @@ -291,7 +291,7 @@ class ModelInstance: streaming=streaming ) - def _round_robin_invoke(self, function: callable, *args, **kwargs): + def _round_robin_invoke(self, function: Callable, *args, **kwargs): """ Round-robin invoke :param function: function to invoke @@ -437,6 +437,7 @@ class LBModelManager: while True: current_index = redis_client.incr(cache_key) + current_index = cast(int, current_index) if current_index >= 10000000: current_index = 1 redis_client.set(cache_key, current_index) @@ -499,7 +500,10 @@ class LBModelManager: config.id ) - return redis_client.exists(cooldown_cache_key) + + res = redis_client.exists(cooldown_cache_key) + res = cast(bool, res) + return res @classmethod def get_config_in_cooldown_and_ttl(cls, tenant_id: str, @@ -528,4 +532,5 @@ class LBModelManager: if ttl == -2: return False, 0 + ttl = cast(int, ttl) return True, ttl diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index f0a3997204..f88f89d588 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -1,10 +1,11 @@ +from collections.abc import Sequence from enum import Enum from typing import Optional from pydantic import BaseModel, ConfigDict from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel +from core.model_runtime.entities.model_entities import ModelType, ProviderModel class ConfigurateMethod(Enum): @@ -93,8 +94,8 @@ class SimpleProviderEntity(BaseModel): label: I18nObject icon_small: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None - supported_model_types: list[ModelType] - models: list[AIModelEntity] = [] + supported_model_types: Sequence[ModelType] + models: list[ProviderModel] = [] class ProviderHelpEntity(BaseModel): @@ -116,7 +117,7 @@ class ProviderEntity(BaseModel): icon_large: Optional[I18nObject] = None background: Optional[str] = None help: Optional[ProviderHelpEntity] = None - supported_model_types: list[ModelType] + supported_model_types: Sequence[ModelType] configurate_methods: list[ConfigurateMethod] models: list[ProviderModel] = [] provider_credential_schema: Optional[ProviderCredentialSchema] = None diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 83cfffa611..04b539433c 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,6 +1,7 @@ import decimal import os from abc import ABC, abstractmethod +from collections.abc import Mapping from typing import Optional from pydantic import ConfigDict @@ -26,15 +27,16 @@ class AIModel(ABC): """ Base class for all models. """ + model_type: ModelType - model_schemas: list[AIModelEntity] = None + model_schemas: Optional[list[AIModelEntity]] = None started_at: float = 0 # pydantic configs model_config = ConfigDict(protected_namespaces=()) @abstractmethod - def validate_credentials(self, model: str, credentials: dict) -> None: + def validate_credentials(self, model: str, credentials: Mapping) -> None: """ Validate model credentials @@ -90,8 +92,8 @@ class AIModel(ABC): # get price info from predefined model schema price_config: Optional[PriceConfig] = None - if model_schema: - price_config: PriceConfig = model_schema.pricing + if model_schema and model_schema.pricing: + price_config = model_schema.pricing # get unit price unit_price = None @@ -103,13 +105,15 @@ class AIModel(ABC): if unit_price is None: return PriceInfo( - unit_price=decimal.Decimal('0.0'), - unit=decimal.Decimal('0.0'), - total_amount=decimal.Decimal('0.0'), + unit_price=decimal.Decimal("0.0"), + unit=decimal.Decimal("0.0"), + total_amount=decimal.Decimal("0.0"), currency="USD", ) # calculate total amount + if not price_config: + raise ValueError(f"Price config not found for model {model}") total_amount = tokens * unit_price * price_config.unit total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) @@ -209,7 +213,7 @@ class AIModel(ABC): return model_schemas - def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]: + def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]: """ Get model schema by model name and credentials @@ -231,7 +235,7 @@ class AIModel(ABC): return None - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: """ Get customizable model schema from credentials @@ -240,8 +244,8 @@ class AIModel(ABC): :return: model schema """ return self._get_customizable_model_schema(model, credentials) - - def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + + def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: """ Get customizable model schema and fill in the template """ @@ -249,7 +253,7 @@ class AIModel(ABC): if not schema: return None - + # fill in the template new_parameter_rules = [] for parameter_rule in schema.parameter_rules: @@ -271,10 +275,20 @@ class AIModel(ABC): parameter_rule.help = I18nObject( en_US=default_parameter_rule['help']['en_US'], ) - if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']): - parameter_rule.help.en_US = default_parameter_rule['help']['en_US'] - if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']): - parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US']) + if ( + parameter_rule.help + and not parameter_rule.help.en_US + and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"]) + ): + parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"] + if ( + parameter_rule.help + and not parameter_rule.help.zh_Hans + and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"]) + ): + parameter_rule.help.zh_Hans = default_parameter_rule["help"].get( + "zh_Hans", default_parameter_rule["help"]["en_US"] + ) except ValueError: pass @@ -284,7 +298,7 @@ class AIModel(ABC): return schema - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: """ Get customizable model schema @@ -304,7 +318,7 @@ class AIModel(ABC): default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name) if not default_parameter_rule: - raise Exception(f'Invalid model parameter rule name {name}') + raise Exception(f"Invalid model parameter rule name {name}") return default_parameter_rule @@ -318,4 +332,4 @@ class AIModel(ABC): :param text: plain text of prompt. You need to convert the original message to plain text :return: number of tokens """ - return GPT2Tokenizer.get_num_tokens(text) \ No newline at end of file + return GPT2Tokenizer.get_num_tokens(text) diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index ef633c61cd..2ade452cf0 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -3,7 +3,7 @@ import os import re import time from abc import abstractmethod -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Optional, Union from pydantic import ConfigDict @@ -43,7 +43,7 @@ class LargeLanguageModel(AIModel): def invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ + stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ -> Union[LLMResult, Generator]: """ Invoke large language model @@ -129,7 +129,7 @@ class LargeLanguageModel(AIModel): user=user, callbacks=callbacks ) - else: + elif isinstance(result, LLMResult): self._trigger_after_invoke_callbacks( model=model, result=result, @@ -148,7 +148,7 @@ class LargeLanguageModel(AIModel): def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]: """ Code block mode wrapper, ensure the response is a code block with output markdown quote @@ -196,7 +196,7 @@ if you are not sure about the structure. # override the system message prompt_messages[0] = SystemPromptMessage( content=block_prompts - .replace("{{instructions}}", prompt_messages[0].content) + .replace("{{instructions}}", str(prompt_messages[0].content)) ) else: # insert the system message @@ -274,8 +274,9 @@ if you are not sure about the structure. else: yield piece continue - new_piece = "" + new_piece: str = "" for char in piece: + char = str(char) if state == "normal": if char == "`": state = "in_backticks" @@ -340,7 +341,7 @@ if you are not sure about the structure. if state == "done": continue - new_piece = "" + new_piece: str = "" for char in piece: if state == "search_start": if char == "`": @@ -365,7 +366,7 @@ if you are not sure about the structure. # If backticks were counted but we're still collecting content, it was a false start new_piece += "`" * backtick_count backtick_count = 0 - new_piece += char + new_piece += str(char) elif state == "done": break @@ -388,13 +389,14 @@ if you are not sure about the structure. prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator: + user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator: """ Invoke result generator :param result: result generator :return: result generator """ + callbacks = callbacks or [] prompt_message = AssistantPromptMessage( content="" ) @@ -489,6 +491,7 @@ if you are not sure about the structure. def _llm_result_to_stream(self, result: LLMResult) -> Generator: """ +from typing_extensions import deprecated Transform llm result to stream :param result: llm result @@ -531,7 +534,7 @@ if you are not sure about the structure. return [] - def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode: + def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode: """ Get model mode @@ -595,7 +598,7 @@ if you are not sure about the structure. prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: list[Callback] = None) -> None: + user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: """ Trigger before invoke callbacks @@ -633,7 +636,7 @@ if you are not sure about the structure. prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: list[Callback] = None) -> None: + user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: """ Trigger new chunk callbacks @@ -672,7 +675,7 @@ if you are not sure about the structure. prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: list[Callback] = None) -> None: + user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: """ Trigger after invoke callbacks @@ -712,7 +715,7 @@ if you are not sure about the structure. prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: list[Callback] = None) -> None: + user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: """ Trigger invoke error callbacks diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index a893d023c0..51dd3b7e28 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -1,5 +1,6 @@ import os from abc import ABC, abstractmethod +from typing import Optional from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source from core.model_runtime.entities.model_entities import AIModelEntity, ModelType @@ -9,7 +10,7 @@ from core.tools.utils.yaml_utils import load_yaml_file class ModelProvider(ABC): - provider_schema: ProviderEntity = None + provider_schema: Optional[ProviderEntity] = None model_instance_map: dict[str, AIModel] = {} @abstractmethod @@ -28,23 +29,23 @@ class ModelProvider(ABC): def get_provider_schema(self) -> ProviderEntity: """ Get provider schema - + :return: provider schema """ if self.provider_schema: return self.provider_schema - + # get dirname of the current path provider_name = self.__class__.__module__.split('.')[-1] # get the path of the model_provider classes base_path = os.path.abspath(__file__) current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name) - + # read provider schema from yaml file yaml_path = os.path.join(current_path, f'{provider_name}.yaml') yaml_data = load_yaml_file(yaml_path, ignore_error=True) - + try: # yaml_data to entity provider_schema = ProviderEntity(**yaml_data) @@ -53,7 +54,7 @@ class ModelProvider(ABC): # cache schema self.provider_schema = provider_schema - + return provider_schema def models(self, model_type: ModelType) -> list[AIModelEntity]: @@ -84,7 +85,7 @@ class ModelProvider(ABC): :return: """ # get dirname of the current path - provider_name = self.__class__.__module__.split('.')[-1] + provider_name = self.__class__.__module__.split(".")[-1] if f"{provider_name}.{model_type.value}" in self.model_instance_map: return self.model_instance_map[f"{provider_name}.{model_type.value}"] @@ -101,11 +102,17 @@ class ModelProvider(ABC): # Dynamic loading {model_type_name}.py file and find the subclass of AIModel parent_module = '.'.join(self.__class__.__module__.split('.')[:-1]) mod = import_module_from_source( - f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path) - model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, - get_subclasses_from_module(mod, AIModel)), None) + module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path + ) + model_class = next( + filter( + lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, + get_subclasses_from_module(mod, AIModel), + ), + None, + ) if not model_class: - raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}') + raise Exception(f"Missing AIModel Class for model type {model_type} in {model_type_py_path}") model_instance_map = model_class() self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index a4dbaabfc9..b1660afafb 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -1,5 +1,6 @@ import logging import os +from collections.abc import Sequence from typing import Optional from pydantic import BaseModel, ConfigDict @@ -16,20 +17,21 @@ logger = logging.getLogger(__name__) class ModelProviderExtension(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + provider_instance: ModelProvider name: str position: Optional[int] = None - model_config = ConfigDict(arbitrary_types_allowed=True) class ModelProviderFactory: - model_provider_extensions: dict[str, ModelProviderExtension] = None + model_provider_extensions: Optional[dict[str, ModelProviderExtension]] = None def __init__(self) -> None: # for cache in memory self.get_providers() - def get_providers(self) -> list[ProviderEntity]: + def get_providers(self) -> Sequence[ProviderEntity]: """ Get all providers :return: list of providers @@ -39,7 +41,7 @@ class ModelProviderFactory: # traverse all model_provider_extensions providers = [] - for name, model_provider_extension in model_provider_extensions.items(): + for model_provider_extension in model_provider_extensions.values(): # get model_provider instance model_provider_instance = model_provider_extension.provider_instance @@ -57,7 +59,7 @@ class ModelProviderFactory: # return providers return providers - def provider_credentials_validate(self, provider: str, credentials: dict) -> dict: + def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict: """ Validate provider credentials @@ -74,6 +76,9 @@ class ModelProviderFactory: # get provider_credential_schema and validate credentials according to the rules provider_credential_schema = provider_schema.provider_credential_schema + if not provider_credential_schema: + raise ValueError(f"Provider {provider} does not have provider_credential_schema") + # validate provider credential schema validator = ProviderCredentialSchemaValidator(provider_credential_schema) filtered_credentials = validator.validate_and_filter(credentials) @@ -83,8 +88,9 @@ class ModelProviderFactory: return filtered_credentials - def model_credentials_validate(self, provider: str, model_type: ModelType, - model: str, credentials: dict) -> dict: + def model_credentials_validate( + self, *, provider: str, model_type: ModelType, model: str, credentials: dict + ) -> dict: """ Validate model credentials @@ -103,6 +109,9 @@ class ModelProviderFactory: # get model_credential_schema and validate credentials according to the rules model_credential_schema = provider_schema.model_credential_schema + if not model_credential_schema: + raise ValueError(f"Provider {provider} does not have model_credential_schema") + # validate model credential schema validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) filtered_credentials = validator.validate_and_filter(credentials) @@ -115,11 +124,13 @@ class ModelProviderFactory: return filtered_credentials - def get_models(self, - provider: Optional[str] = None, - model_type: Optional[ModelType] = None, - provider_configs: Optional[list[ProviderConfig]] = None) \ - -> list[SimpleProviderEntity]: + def get_models( + self, + *, + provider: Optional[str] = None, + model_type: Optional[ModelType] = None, + provider_configs: Optional[list[ProviderConfig]] = None, + ) -> list[SimpleProviderEntity]: """ Get all models for given model type @@ -128,6 +139,8 @@ class ModelProviderFactory: :param provider_configs: list of provider configs :return: list of models """ + provider_configs = provider_configs or [] + # scan all providers model_provider_extensions = self._get_model_provider_map() @@ -184,7 +197,7 @@ class ModelProviderFactory: # get the provider extension model_provider_extension = model_provider_extensions.get(provider) if not model_provider_extension: - raise Exception(f'Invalid provider: {provider}') + raise Exception(f"Invalid provider: {provider}") # get the provider instance model_provider_instance = model_provider_extension.provider_instance @@ -192,10 +205,22 @@ class ModelProviderFactory: return model_provider_instance def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]: + """ + Retrieves the model provider map. + + This method retrieves the model provider map, which is a dictionary containing the model provider names as keys + and instances of `ModelProviderExtension` as values. The model provider map is used to store information about + available model providers. + + Returns: + A dictionary containing the model provider map. + + Raises: + None. + """ if self.model_provider_extensions: return self.model_provider_extensions - # get the path of current classes current_path = os.path.abspath(__file__) model_providers_path = os.path.dirname(current_path) @@ -204,8 +229,8 @@ class ModelProviderFactory: model_provider_dir_paths = [ os.path.join(model_providers_path, model_provider_dir) for model_provider_dir in os.listdir(model_providers_path) - if not model_provider_dir.startswith('__') - and os.path.isdir(os.path.join(model_providers_path, model_provider_dir)) + if not model_provider_dir.startswith("__") + and os.path.isdir(os.path.join(model_providers_path, model_provider_dir)) ] # get _position.yaml file path @@ -219,30 +244,33 @@ class ModelProviderFactory: file_names = os.listdir(model_provider_dir_path) - if (model_provider_name + '.py') not in file_names: + if (model_provider_name + ".py") not in file_names: logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.") continue # Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider - py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py') + py_path = os.path.join(model_provider_dir_path, model_provider_name + ".py") model_provider_class = load_single_subclass_from_source( - module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}', + module_name=f"core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}", script_path=py_path, - parent_type=ModelProvider) + parent_type=ModelProvider, + ) if not model_provider_class: logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.") continue - if f'{model_provider_name}.yaml' not in file_names: + if f"{model_provider_name}.yaml" not in file_names: logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.") continue - model_providers.append(ModelProviderExtension( - name=model_provider_name, - provider_instance=model_provider_class(), - position=position_map.get(model_provider_name) - )) + model_providers.append( + ModelProviderExtension( + name=model_provider_name, + provider_instance=model_provider_class(), + position=position_map.get(model_provider_name), + ) + ) sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name) diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 5772f325e1..467a51daf2 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping + import openai from httpx import Timeout @@ -12,7 +14,7 @@ from core.model_runtime.errors.invoke import ( class _CommonOpenAI: - def _to_credential_kwargs(self, credentials: dict) -> dict: + def _to_credential_kwargs(self, credentials: Mapping) -> dict: """ Transform credentials to kwargs for model instance @@ -25,9 +27,9 @@ class _CommonOpenAI: "max_retries": 1, } - if credentials.get('openai_api_base'): - credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/') - credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1' + if credentials.get("openai_api_base"): + openai_api_base = credentials["openai_api_base"].rstrip("/") + credentials_kwargs["base_url"] = openai_api_base + "/v1" if 'openai_organization' in credentials: credentials_kwargs['organization'] = credentials['openai_organization'] @@ -45,24 +47,14 @@ class _CommonOpenAI: :return: Invoke error mapping """ return { - InvokeConnectionError: [ - openai.APIConnectionError, - openai.APITimeoutError - ], - InvokeServerUnavailableError: [ - openai.InternalServerError - ], - InvokeRateLimitError: [ - openai.RateLimitError - ], - InvokeAuthorizationError: [ - openai.AuthenticationError, - openai.PermissionDeniedError - ], + InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError], + InvokeServerUnavailableError: [openai.InternalServerError], + InvokeRateLimitError: [openai.RateLimitError], + InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError], InvokeBadRequestError: [ openai.BadRequestError, openai.NotFoundError, openai.UnprocessableEntityError, - openai.APIError - ] + openai.APIError, + ], } diff --git a/api/core/model_runtime/model_providers/openai/openai.py b/api/core/model_runtime/model_providers/openai/openai.py index d4a4e24c97..66efd4797f 100644 --- a/api/core/model_runtime/model_providers/openai/openai.py +++ b/api/core/model_runtime/model_providers/openai/openai.py @@ -1,4 +1,5 @@ import logging +from collections.abc import Mapping from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -9,7 +10,7 @@ logger = logging.getLogger(__name__) class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: + def validate_provider_credentials(self, credentials: Mapping) -> None: """ Validate provider credentials if validate failed, raise exception