fix(core): Fix incorrect type hints. (#5427)

-LAN- 2024-06-20 15:16:21 +08:00 committed by GitHub
parent e4259a8f13
commit 23fa3dedc4
11 changed files with 166 additions and 117 deletions

View File

@@ -1,5 +1,5 @@
import enum
import importlib
import importlib.util
import json
import logging
import os
@@ -74,6 +74,8 @@ class Extensible:
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + '.py')
spec = importlib.util.spec_from_file_location(extension_name, py_path)
if not spec or not spec.loader:
raise Exception(f"Failed to load module {extension_name} from {py_path}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
@@ -108,6 +110,6 @@ class Extensible:
position=position
))
sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)
sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name)
return sorted_extensions
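
The new `if not spec or not spec.loader` guard exists because `importlib.util.spec_from_file_location` is annotated as returning `Optional[ModuleSpec]`, and `ModuleSpec.loader` is itself optional, so a strict type checker rejects `spec.loader.exec_module(...)` until both are narrowed. The switch to `import importlib.util` also binds the submodule explicitly instead of relying on another import having pulled it in. A minimal runnable sketch of the pattern (names illustrative):

import importlib.util

def load_extension_module(extension_name: str, py_path: str):
    # spec_from_file_location returns Optional[ModuleSpec], and spec.loader
    # is Optional[Loader]; narrow both before calling exec_module().
    spec = importlib.util.spec_from_file_location(extension_name, py_path)
    if not spec or not spec.loader:
        raise Exception(f"Failed to load module {extension_name} from {py_path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)  # accepted: loader is now known non-None
    return mod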

View File

@@ -5,11 +5,7 @@ from types import ModuleType
from typing import AnyStr
def import_module_from_source(
module_name: str,
py_file_path: AnyStr,
use_lazy_loader: bool = False
) -> ModuleType:
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
"""
Importing a module from the source file directly
"""
@@ -17,9 +13,13 @@ def import_module_from_source(
existed_spec = importlib.util.find_spec(module_name)
if existed_spec:
spec = existed_spec
if not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
spec = importlib.util.spec_from_file_location(module_name, py_file_path)
if not spec or not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
if use_lazy_loader:
# Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
spec.loader = importlib.util.LazyLoader(spec.loader)
@@ -29,7 +29,7 @@ def import_module_from_source(
spec.loader.exec_module(module)
return module
except Exception as e:
logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}')
logging.exception(f"Failed to load module {module_name} from {py_file_path}: {str(e)}")
raise e
@@ -43,15 +43,14 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
def load_single_subclass_from_source(
module_name: str,
script_path: AnyStr,
parent_type: type,
use_lazy_loader: bool = False,
*, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
) -> type:
"""
Load a single subclass from the source
"""
module = import_module_from_source(module_name, script_path, use_lazy_loader)
module = import_module_from_source(
module_name=module_name, py_file_path=script_path, use_lazy_loader=use_lazy_loader
)
subclasses = get_subclasses_from_module(module, parent_type)
match len(subclasses):
case 1:
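
The bare `*` now leading these signatures makes every parameter keyword-only, so call sites must name their arguments (as the rewritten `import_module_from_source(...)` call does), and a shuffled positional call fails immediately instead of binding `script_path` to the wrong parameter. A small sketch with a hypothetical function:

def import_source(*, module_name: str, py_file_path: str, use_lazy_loader: bool = False) -> str:
    # everything after the bare * must be passed by keyword
    return f"{module_name} <- {py_file_path} (lazy={use_lazy_loader})"

import_source(module_name="demo", py_file_path="demo.py")  # OK
# import_source("demo", "demo.py")  # TypeError: takes 0 positional arguments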

View File

@@ -1,15 +1,12 @@
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, AnyStr
from typing import Any
from core.tools.utils.yaml_utils import load_yaml_file
def get_position_map(
folder_path: AnyStr,
file_name: str = '_position.yaml',
) -> dict[str, int]:
def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping from name to index from a YAML file
:param folder_path:
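
`AnyStr` is a constrained TypeVar over `str` and `bytes`, meant for functions generic over both where input and output types vary together; annotating a single parameter with it, as the old `folder_path: AnyStr` did, is a misuse that type checkers flag. Since folder paths here are always text, plain `str` is the honest hint. What `AnyStr` is actually for, in a tiny sketch:

from typing import AnyStr

def double(value: AnyStr) -> AnyStr:
    # AnyStr links the two ends: str in -> str out, bytes in -> bytes out
    return value + value

double("ab")   # 'abab'
double(b"ab")  # b'abab'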

View File

@@ -1,6 +1,6 @@
import logging
import os
from collections.abc import Generator
from collections.abc import Callable, Generator
from typing import IO, Optional, Union, cast
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@@ -102,7 +102,7 @@ class ModelInstance:
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
@@ -291,7 +291,7 @@ class ModelInstance:
streaming=streaming
)
def _round_robin_invoke(self, function: callable, *args, **kwargs):
def _round_robin_invoke(self, function: Callable, *args, **kwargs):
"""
Round-robin invoke
:param function: function to invoke
@@ -437,6 +437,7 @@ class LBModelManager:
while True:
current_index = redis_client.incr(cache_key)
current_index = cast(int, current_index)
if current_index >= 10000000:
current_index = 1
redis_client.set(cache_key, current_index)
@@ -499,7 +500,10 @@ class LBModelManager:
config.id
)
return redis_client.exists(cooldown_cache_key)
res = redis_client.exists(cooldown_cache_key)
res = cast(bool, res)
return res
@classmethod
def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
@@ -528,4 +532,5 @@ class LBModelManager:
if ttl == -2:
return False, 0
ttl = cast(int, ttl)
return True, ttl
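
Two recurring fixes in this file: `callable` is the builtin predicate, not a type, so the annotation becomes `collections.abc.Callable`; and the Redis client's return values are loosely annotated, so the code pins the expected type with `typing.cast`, a no-op at runtime that exists purely for the checker. A reduced sketch (the `incr` callable stands in for the Redis client):

from collections.abc import Callable  # the annotation; builtin `callable` is only a predicate
from typing import cast

def next_round_robin_index(incr: Callable[[str], object], cache_key: str) -> int:
    raw = incr(cache_key)   # client libraries often annotate this loosely
    index = cast(int, raw)  # cast() does nothing at runtime; it informs the checker
    return 1 if index >= 10000000 else index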

View File

@@ -1,10 +1,11 @@
from collections.abc import Sequence
from enum import Enum
from typing import Optional
from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel
from core.model_runtime.entities.model_entities import ModelType, ProviderModel
class ConfigurateMethod(Enum):
@@ -93,8 +94,8 @@ class SimpleProviderEntity(BaseModel):
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]
models: list[AIModelEntity] = []
supported_model_types: Sequence[ModelType]
models: list[ProviderModel] = []
class ProviderHelpEntity(BaseModel):
@@ -116,7 +117,7 @@ class ProviderEntity(BaseModel):
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
supported_model_types: list[ModelType]
supported_model_types: Sequence[ModelType]
configurate_methods: list[ConfigurateMethod]
models: list[ProviderModel] = []
provider_credential_schema: Optional[ProviderCredentialSchema] = None
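
`list[X]` is invariant, so a field or parameter typed `list[ModelType]` refuses a list of a subtype and invites mutation; `Sequence` is covariant and read-only, which is presumably the motivation for `supported_model_types: Sequence[ModelType]`. The variance difference in a sketch (classes illustrative; the flagged call is a static error only and runs fine):

from collections.abc import Sequence

class Animal: ...
class Dog(Animal): ...

def count_invariant(xs: list[Animal]) -> int:
    return len(xs)

def count_covariant(xs: Sequence[Animal]) -> int:
    return len(xs)

dogs: list[Dog] = [Dog()]
count_covariant(dogs)  # OK: a list[Dog] is a Sequence[Animal]
count_invariant(dogs)  # rejected by mypy/pyright: list is invariant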

View File

@@ -1,6 +1,7 @@
import decimal
import os
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Optional
from pydantic import ConfigDict
@@ -26,15 +27,16 @@ class AIModel(ABC):
"""
Base class for all models.
"""
model_type: ModelType
model_schemas: list[AIModelEntity] = None
model_schemas: Optional[list[AIModelEntity]] = None
started_at: float = 0
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@abstractmethod
def validate_credentials(self, model: str, credentials: dict) -> None:
def validate_credentials(self, model: str, credentials: Mapping) -> None:
"""
Validate model credentials
@@ -90,8 +92,8 @@ class AIModel(ABC):
# get price info from predefined model schema
price_config: Optional[PriceConfig] = None
if model_schema:
price_config: PriceConfig = model_schema.pricing
if model_schema and model_schema.pricing:
price_config = model_schema.pricing
# get unit price
unit_price = None
@@ -103,13 +105,15 @@ class AIModel(ABC):
if unit_price is None:
return PriceInfo(
unit_price=decimal.Decimal('0.0'),
unit=decimal.Decimal('0.0'),
total_amount=decimal.Decimal('0.0'),
unit_price=decimal.Decimal("0.0"),
unit=decimal.Decimal("0.0"),
total_amount=decimal.Decimal("0.0"),
currency="USD",
)
# calculate total amount
if not price_config:
raise ValueError(f"Price config not found for model {model}")
total_amount = tokens * unit_price * price_config.unit
total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
@@ -209,7 +213,7 @@ class AIModel(ABC):
return model_schemas
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
"""
Get model schema by model name and credentials
@@ -231,7 +235,7 @@ class AIModel(ABC):
return None
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
"""
Get customizable model schema from credentials
@@ -240,8 +244,8 @@ class AIModel(ABC):
:return: model schema
"""
return self._get_customizable_model_schema(model, credentials)
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
"""
Get customizable model schema and fill in the template
"""
@@ -249,7 +253,7 @@ class AIModel(ABC):
if not schema:
return None
# fill in the template
new_parameter_rules = []
for parameter_rule in schema.parameter_rules:
@@ -271,10 +275,20 @@ class AIModel(ABC):
parameter_rule.help = I18nObject(
en_US=default_parameter_rule['help']['en_US'],
)
if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
if (
parameter_rule.help
and not parameter_rule.help.en_US
and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
):
parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
if (
parameter_rule.help
and not parameter_rule.help.zh_Hans
and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
):
parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
"zh_Hans", default_parameter_rule["help"]["en_US"]
)
except ValueError:
pass
@@ -284,7 +298,7 @@ class AIModel(ABC):
return schema
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
"""
Get customizable model schema
@@ -304,7 +318,7 @@ class AIModel(ABC):
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
if not default_parameter_rule:
raise Exception(f'Invalid model parameter rule name {name}')
raise Exception(f"Invalid model parameter rule name {name}")
return default_parameter_rule
@@ -318,4 +332,4 @@ class AIModel(ABC):
:param text: plain text of prompt. You need to convert the original message to plain text
:return: number of tokens
"""
return GPT2Tokenizer.get_num_tokens(text)
return GPT2Tokenizer.get_num_tokens(text)
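
The `price_config` hunk fixes a pattern checkers reject: re-annotating a name already declared as `Optional[PriceConfig]` with a narrower type on assignment. The new code declares once, assigns under a guard, and raises before any use while the value could still be None; the parallel `dict` -> `Mapping` changes widen accepted inputs while promising read-only access. The narrowing, reduced:

from decimal import Decimal
from typing import Optional

def total_amount(tokens: int, unit_price: Decimal, price_unit: Optional[Decimal]) -> Decimal:
    # narrow the Optional explicitly instead of re-annotating the name
    if price_unit is None:
        raise ValueError("Price config not found for model")
    return (tokens * unit_price * price_unit).quantize(Decimal("0.0000001"))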

View File

@@ -3,7 +3,7 @@ import os
import re
import time
from abc import abstractmethod
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import Optional, Union
from pydantic import ConfigDict
@@ -43,7 +43,7 @@ class LargeLanguageModel(AIModel):
def invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
@@ -129,7 +129,7 @@ class LargeLanguageModel(AIModel):
user=user,
callbacks=callbacks
)
else:
elif isinstance(result, LLMResult):
self._trigger_after_invoke_callbacks(
model=model,
result=result,
@@ -148,7 +148,7 @@ class LargeLanguageModel(AIModel):
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper, ensure the response is a code block with output markdown quote
@@ -196,7 +196,7 @@ if you are not sure about the structure.
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{instructions}}", str(prompt_messages[0].content))
)
else:
# insert the system message
@@ -274,8 +274,9 @@ if you are not sure about the structure.
else:
yield piece
continue
new_piece = ""
new_piece: str = ""
for char in piece:
char = str(char)
if state == "normal":
if char == "`":
state = "in_backticks"
@@ -340,7 +341,7 @@ if you are not sure about the structure.
if state == "done":
continue
new_piece = ""
new_piece: str = ""
for char in piece:
if state == "search_start":
if char == "`":
@@ -365,7 +366,7 @@ if you are not sure about the structure.
# If backticks were counted but we're still collecting content, it was a false start
new_piece += "`" * backtick_count
backtick_count = 0
new_piece += char
new_piece += str(char)
elif state == "done":
break
@@ -388,13 +389,14 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator:
"""
Invoke result generator
:param result: result generator
:return: result generator
"""
callbacks = callbacks or []
prompt_message = AssistantPromptMessage(
content=""
)
@@ -489,6 +491,7 @@ if you are not sure about the structure.
def _llm_result_to_stream(self, result: LLMResult) -> Generator:
"""
Transform llm result to stream
:param result: llm result
@@ -531,7 +534,7 @@ if you are not sure about the structure.
return []
def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
"""
Get model mode
@@ -595,7 +598,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
"""
Trigger before invoke callbacks
@@ -633,7 +636,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
"""
Trigger new chunk callbacks
@@ -672,7 +675,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
"""
Trigger after invoke callbacks
@@ -712,7 +715,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
"""
Trigger invoke error callbacks
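
`callbacks: list[Callback] = None` was wrong twice over: the annotation does not admit None, and a shared mutable default is a classic trap. The recurring fix is `Optional[list[Callback]] = None` in the signature plus normalization in the body, as the added `callbacks = callbacks or []` shows. A minimal sketch:

from typing import Optional

def invoke(prompt: str, callbacks: Optional[list] = None) -> str:
    callbacks = callbacks or []  # fresh list per call, never a shared default
    for callback in callbacks:
        callback(prompt)
    return prompt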

View File

@@ -1,5 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Optional
from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@@ -9,7 +10,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
class ModelProvider(ABC):
provider_schema: ProviderEntity = None
provider_schema: Optional[ProviderEntity] = None
model_instance_map: dict[str, AIModel] = {}
@abstractmethod
@@ -28,23 +29,23 @@
def get_provider_schema(self) -> ProviderEntity:
"""
Get provider schema
:return: provider schema
"""
if self.provider_schema:
return self.provider_schema
# get dirname of the current path
provider_name = self.__class__.__module__.split('.')[-1]
# get the path of the model_provider classes
base_path = os.path.abspath(__file__)
current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
# read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = load_yaml_file(yaml_path, ignore_error=True)
try:
# yaml_data to entity
provider_schema = ProviderEntity(**yaml_data)
@@ -53,7 +54,7 @@ class ModelProvider(ABC):
# cache schema
self.provider_schema = provider_schema
return provider_schema
def models(self, model_type: ModelType) -> list[AIModelEntity]:
@@ -84,7 +85,7 @@ class ModelProvider(ABC):
:return:
"""
# get dirname of the current path
provider_name = self.__class__.__module__.split('.')[-1]
provider_name = self.__class__.__module__.split(".")[-1]
if f"{provider_name}.{model_type.value}" in self.model_instance_map:
return self.model_instance_map[f"{provider_name}.{model_type.value}"]
@@ -101,11 +102,17 @@ class ModelProvider(ABC):
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
mod = import_module_from_source(
f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
get_subclasses_from_module(mod, AIModel)), None)
module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
)
model_class = next(
filter(
lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
get_subclasses_from_module(mod, AIModel),
),
None,
)
if not model_class:
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
raise Exception(f"Missing AIModel Class for model type {model_type} in {model_type_py_path}")
model_instance_map = model_class()
self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map
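
`provider_schema: ProviderEntity = None` mismatched hint and default; `Optional[ProviderEntity] = None` matches the attribute's lifecycle (empty until first computed), and the existing `if self.provider_schema:` memoization then narrows cleanly. The shape of the pattern, reduced (the payload is a stand-in for the YAML load):

from typing import Optional

class ProviderBase:
    _schema: Optional[dict] = None  # cache slot starts empty, hence Optional

    def get_schema(self) -> dict:
        if self._schema is not None:  # narrows Optional[dict] to dict
            return self._schema
        self._schema = {"provider": "example"}
        return self._schema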

View File

@@ -1,5 +1,6 @@
import logging
import os
from collections.abc import Sequence
from typing import Optional
from pydantic import BaseModel, ConfigDict
@@ -16,20 +17,21 @@ logger = logging.getLogger(__name__)
class ModelProviderExtension(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
provider_instance: ModelProvider
name: str
position: Optional[int] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
class ModelProviderFactory:
model_provider_extensions: dict[str, ModelProviderExtension] = None
model_provider_extensions: Optional[dict[str, ModelProviderExtension]] = None
def __init__(self) -> None:
# for cache in memory
self.get_providers()
def get_providers(self) -> list[ProviderEntity]:
def get_providers(self) -> Sequence[ProviderEntity]:
"""
Get all providers
:return: list of providers
@@ -39,7 +41,7 @@ class ModelProviderFactory:
# traverse all model_provider_extensions
providers = []
for name, model_provider_extension in model_provider_extensions.items():
for model_provider_extension in model_provider_extensions.values():
# get model_provider instance
model_provider_instance = model_provider_extension.provider_instance
@@ -57,7 +59,7 @@ class ModelProviderFactory:
# return providers
return providers
def provider_credentials_validate(self, provider: str, credentials: dict) -> dict:
def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict:
"""
Validate provider credentials
@@ -74,6 +76,9 @@ class ModelProviderFactory:
# get provider_credential_schema and validate credentials according to the rules
provider_credential_schema = provider_schema.provider_credential_schema
if not provider_credential_schema:
raise ValueError(f"Provider {provider} does not have provider_credential_schema")
# validate provider credential schema
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials)
@@ -83,8 +88,9 @@ class ModelProviderFactory:
return filtered_credentials
def model_credentials_validate(self, provider: str, model_type: ModelType,
model: str, credentials: dict) -> dict:
def model_credentials_validate(
self, *, provider: str, model_type: ModelType, model: str, credentials: dict
) -> dict:
"""
Validate model credentials
@@ -103,6 +109,9 @@ class ModelProviderFactory:
# get model_credential_schema and validate credentials according to the rules
model_credential_schema = provider_schema.model_credential_schema
if not model_credential_schema:
raise ValueError(f"Provider {provider} does not have model_credential_schema")
# validate model credential schema
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials)
@@ -115,11 +124,13 @@ class ModelProviderFactory:
return filtered_credentials
def get_models(self,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
provider_configs: Optional[list[ProviderConfig]] = None) \
-> list[SimpleProviderEntity]:
def get_models(
self,
*,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
provider_configs: Optional[list[ProviderConfig]] = None,
) -> list[SimpleProviderEntity]:
"""
Get all models for given model type
@@ -128,6 +139,8 @@ class ModelProviderFactory:
:param provider_configs: list of provider configs
:return: list of models
"""
provider_configs = provider_configs or []
# scan all providers
model_provider_extensions = self._get_model_provider_map()
@@ -184,7 +197,7 @@ class ModelProviderFactory:
# get the provider extension
model_provider_extension = model_provider_extensions.get(provider)
if not model_provider_extension:
raise Exception(f'Invalid provider: {provider}')
raise Exception(f"Invalid provider: {provider}")
# get the provider instance
model_provider_instance = model_provider_extension.provider_instance
@@ -192,10 +205,22 @@ class ModelProviderFactory:
return model_provider_instance
def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
"""
Retrieves the model provider map.
This method retrieves the model provider map, which is a dictionary containing the model provider names as keys
and instances of `ModelProviderExtension` as values. The model provider map is used to store information about
available model providers.
Returns:
A dictionary containing the model provider map.
Raises:
None.
"""
if self.model_provider_extensions:
return self.model_provider_extensions
# get the path of current classes
current_path = os.path.abspath(__file__)
model_providers_path = os.path.dirname(current_path)
@@ -204,8 +229,8 @@ class ModelProviderFactory:
model_provider_dir_paths = [
os.path.join(model_providers_path, model_provider_dir)
for model_provider_dir in os.listdir(model_providers_path)
if not model_provider_dir.startswith('__')
and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
if not model_provider_dir.startswith("__")
and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
]
# get _position.yaml file path
@@ -219,30 +244,33 @@ class ModelProviderFactory:
file_names = os.listdir(model_provider_dir_path)
if (model_provider_name + '.py') not in file_names:
if (model_provider_name + ".py") not in file_names:
logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.")
continue
# Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
py_path = os.path.join(model_provider_dir_path, model_provider_name + ".py")
model_provider_class = load_single_subclass_from_source(
module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
module_name=f"core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}",
script_path=py_path,
parent_type=ModelProvider)
parent_type=ModelProvider,
)
if not model_provider_class:
logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
continue
if f'{model_provider_name}.yaml' not in file_names:
if f"{model_provider_name}.yaml" not in file_names:
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
continue
model_providers.append(ModelProviderExtension(
name=model_provider_name,
provider_instance=model_provider_class(),
position=position_map.get(model_provider_name)
))
model_providers.append(
ModelProviderExtension(
name=model_provider_name,
provider_instance=model_provider_class(),
position=position_map.get(model_provider_name),
)
)
sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
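
Because `ProviderEntity` declares `provider_credential_schema` and `model_credential_schema` as Optional, the `if not ...: raise ValueError(...)` guards added above are what let the validators receive a non-None schema; failing fast with a provider-specific message beats an opaque `AttributeError` downstream. The core of the guard (the filtering stands in for the real validator):

from typing import Optional

def credentials_validate(provider: str, schema: Optional[dict], credentials: dict) -> dict:
    if not schema:
        # fail fast instead of letting None reach the validator
        raise ValueError(f"Provider {provider} does not have provider_credential_schema")
    return {k: v for k, v in credentials.items() if k in schema}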

View File

@@ -1,3 +1,5 @@
from collections.abc import Mapping
import openai
from httpx import Timeout
@@ -12,7 +14,7 @@ from core.model_runtime.errors.invoke import (
class _CommonOpenAI:
def _to_credential_kwargs(self, credentials: dict) -> dict:
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
"""
Transform credentials to kwargs for model instance
@@ -25,9 +27,9 @@ class _CommonOpenAI:
"max_retries": 1,
}
if credentials.get('openai_api_base'):
credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/')
credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1'
if credentials.get("openai_api_base"):
openai_api_base = credentials["openai_api_base"].rstrip("/")
credentials_kwargs["base_url"] = openai_api_base + "/v1"
if 'openai_organization' in credentials:
credentials_kwargs['organization'] = credentials['openai_organization']
@@ -45,24 +47,14 @@
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
openai.APIConnectionError,
openai.APITimeoutError
],
InvokeServerUnavailableError: [
openai.InternalServerError
],
InvokeRateLimitError: [
openai.RateLimitError
],
InvokeAuthorizationError: [
openai.AuthenticationError,
openai.PermissionDeniedError
],
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
InvokeServerUnavailableError: [openai.InternalServerError],
InvokeRateLimitError: [openai.RateLimitError],
InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
InvokeBadRequestError: [
openai.BadRequestError,
openai.NotFoundError,
openai.UnprocessableEntityError,
openai.APIError
]
openai.APIError,
],
}
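
Typing `credentials` as `Mapping` promises the method will not mutate its input, so the old in-place `credentials['openai_api_base'] = ...` had to become a local variable, exactly as the hunk shows. A runnable sketch of the same shape (key names follow the diff; the full method also carries an API key, omitted here):

from collections.abc import Mapping

def to_credential_kwargs(credentials: Mapping) -> dict:
    credentials_kwargs: dict = {"max_retries": 1}
    if credentials.get("openai_api_base"):
        # Mapping has no __setitem__: normalize into a local, not the input
        openai_api_base = credentials["openai_api_base"].rstrip("/")
        credentials_kwargs["base_url"] = openai_api_base + "/v1"
    return credentials_kwargs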

View File

@@ -1,4 +1,5 @@
import logging
from collections.abc import Mapping
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -9,7 +10,7 @@ logger = logging.getLogger(__name__)
class OpenAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
def validate_provider_credentials(self, credentials: Mapping) -> None:
"""
Validate provider credentials
if validate failed, raise exception