mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-16 14:25:54 +08:00
fix(core): Fix incorrect type hints. (#5427)
This commit is contained in:
parent
e4259a8f13
commit
23fa3dedc4
@ -1,5 +1,5 @@
|
||||
import enum
|
||||
import importlib
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@ -74,6 +74,8 @@ class Extensible:
|
||||
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
|
||||
py_path = os.path.join(subdir_path, extension_name + '.py')
|
||||
spec = importlib.util.spec_from_file_location(extension_name, py_path)
|
||||
if not spec or not spec.loader:
|
||||
raise Exception(f"Failed to load module {extension_name} from {py_path}")
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
@ -108,6 +110,6 @@ class Extensible:
|
||||
position=position
|
||||
))
|
||||
|
||||
sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)
|
||||
sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name)
|
||||
|
||||
return sorted_extensions
|
||||
|
@ -5,11 +5,7 @@ from types import ModuleType
|
||||
from typing import AnyStr
|
||||
|
||||
|
||||
def import_module_from_source(
|
||||
module_name: str,
|
||||
py_file_path: AnyStr,
|
||||
use_lazy_loader: bool = False
|
||||
) -> ModuleType:
|
||||
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
|
||||
"""
|
||||
Importing a module from the source file directly
|
||||
"""
|
||||
@ -17,9 +13,13 @@ def import_module_from_source(
|
||||
existed_spec = importlib.util.find_spec(module_name)
|
||||
if existed_spec:
|
||||
spec = existed_spec
|
||||
if not spec.loader:
|
||||
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
|
||||
else:
|
||||
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
|
||||
spec = importlib.util.spec_from_file_location(module_name, py_file_path)
|
||||
if not spec or not spec.loader:
|
||||
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
|
||||
if use_lazy_loader:
|
||||
# Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
|
||||
spec.loader = importlib.util.LazyLoader(spec.loader)
|
||||
@ -29,7 +29,7 @@ def import_module_from_source(
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
except Exception as e:
|
||||
logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}')
|
||||
logging.exception(f"Failed to load module {module_name} from {py_file_path}: {str(e)}")
|
||||
raise e
|
||||
|
||||
|
||||
@ -43,15 +43,14 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
|
||||
|
||||
|
||||
def load_single_subclass_from_source(
|
||||
module_name: str,
|
||||
script_path: AnyStr,
|
||||
parent_type: type,
|
||||
use_lazy_loader: bool = False,
|
||||
*, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
|
||||
) -> type:
|
||||
"""
|
||||
Load a single subclass from the source
|
||||
"""
|
||||
module = import_module_from_source(module_name, script_path, use_lazy_loader)
|
||||
module = import_module_from_source(
|
||||
module_name=module_name, py_file_path=script_path, use_lazy_loader=use_lazy_loader
|
||||
)
|
||||
subclasses = get_subclasses_from_module(module, parent_type)
|
||||
match len(subclasses):
|
||||
case 1:
|
||||
|
@ -1,15 +1,12 @@
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from typing import Any, AnyStr
|
||||
from typing import Any
|
||||
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
def get_position_map(
|
||||
folder_path: AnyStr,
|
||||
file_name: str = '_position.yaml',
|
||||
) -> dict[str, int]:
|
||||
def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]:
|
||||
"""
|
||||
Get the mapping from name to index from a YAML file
|
||||
:param folder_path:
|
||||
|
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import IO, Optional, Union, cast
|
||||
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
@ -102,7 +102,7 @@ class ModelInstance:
|
||||
|
||||
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
|
||||
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -291,7 +291,7 @@ class ModelInstance:
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
def _round_robin_invoke(self, function: callable, *args, **kwargs):
|
||||
def _round_robin_invoke(self, function: Callable, *args, **kwargs):
|
||||
"""
|
||||
Round-robin invoke
|
||||
:param function: function to invoke
|
||||
@ -437,6 +437,7 @@ class LBModelManager:
|
||||
|
||||
while True:
|
||||
current_index = redis_client.incr(cache_key)
|
||||
current_index = cast(int, current_index)
|
||||
if current_index >= 10000000:
|
||||
current_index = 1
|
||||
redis_client.set(cache_key, current_index)
|
||||
@ -499,7 +500,10 @@ class LBModelManager:
|
||||
config.id
|
||||
)
|
||||
|
||||
return redis_client.exists(cooldown_cache_key)
|
||||
|
||||
res = redis_client.exists(cooldown_cache_key)
|
||||
res = cast(bool, res)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
|
||||
@ -528,4 +532,5 @@ class LBModelManager:
|
||||
if ttl == -2:
|
||||
return False, 0
|
||||
|
||||
ttl = cast(int, ttl)
|
||||
return True, ttl
|
||||
|
@ -1,10 +1,11 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel
|
||||
from core.model_runtime.entities.model_entities import ModelType, ProviderModel
|
||||
|
||||
|
||||
class ConfigurateMethod(Enum):
|
||||
@ -93,8 +94,8 @@ class SimpleProviderEntity(BaseModel):
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
supported_model_types: list[ModelType]
|
||||
models: list[AIModelEntity] = []
|
||||
supported_model_types: Sequence[ModelType]
|
||||
models: list[ProviderModel] = []
|
||||
|
||||
|
||||
class ProviderHelpEntity(BaseModel):
|
||||
@ -116,7 +117,7 @@ class ProviderEntity(BaseModel):
|
||||
icon_large: Optional[I18nObject] = None
|
||||
background: Optional[str] = None
|
||||
help: Optional[ProviderHelpEntity] = None
|
||||
supported_model_types: list[ModelType]
|
||||
supported_model_types: Sequence[ModelType]
|
||||
configurate_methods: list[ConfigurateMethod]
|
||||
models: list[ProviderModel] = []
|
||||
provider_credential_schema: Optional[ProviderCredentialSchema] = None
|
||||
|
@ -1,6 +1,7 @@
|
||||
import decimal
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
@ -26,15 +27,16 @@ class AIModel(ABC):
|
||||
"""
|
||||
Base class for all models.
|
||||
"""
|
||||
|
||||
model_type: ModelType
|
||||
model_schemas: list[AIModelEntity] = None
|
||||
model_schemas: Optional[list[AIModelEntity]] = None
|
||||
started_at: float = 0
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@abstractmethod
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
def validate_credentials(self, model: str, credentials: Mapping) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
@ -90,8 +92,8 @@ class AIModel(ABC):
|
||||
|
||||
# get price info from predefined model schema
|
||||
price_config: Optional[PriceConfig] = None
|
||||
if model_schema:
|
||||
price_config: PriceConfig = model_schema.pricing
|
||||
if model_schema and model_schema.pricing:
|
||||
price_config = model_schema.pricing
|
||||
|
||||
# get unit price
|
||||
unit_price = None
|
||||
@ -103,13 +105,15 @@ class AIModel(ABC):
|
||||
|
||||
if unit_price is None:
|
||||
return PriceInfo(
|
||||
unit_price=decimal.Decimal('0.0'),
|
||||
unit=decimal.Decimal('0.0'),
|
||||
total_amount=decimal.Decimal('0.0'),
|
||||
unit_price=decimal.Decimal("0.0"),
|
||||
unit=decimal.Decimal("0.0"),
|
||||
total_amount=decimal.Decimal("0.0"),
|
||||
currency="USD",
|
||||
)
|
||||
|
||||
# calculate total amount
|
||||
if not price_config:
|
||||
raise ValueError(f"Price config not found for model {model}")
|
||||
total_amount = tokens * unit_price * price_config.unit
|
||||
total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
@ -209,7 +213,7 @@ class AIModel(ABC):
|
||||
|
||||
return model_schemas
|
||||
|
||||
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
|
||||
def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get model schema by model name and credentials
|
||||
|
||||
@ -231,7 +235,7 @@ class AIModel(ABC):
|
||||
|
||||
return None
|
||||
|
||||
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema from credentials
|
||||
|
||||
@ -240,8 +244,8 @@ class AIModel(ABC):
|
||||
:return: model schema
|
||||
"""
|
||||
return self._get_customizable_model_schema(model, credentials)
|
||||
|
||||
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
|
||||
def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema and fill in the template
|
||||
"""
|
||||
@ -249,7 +253,7 @@ class AIModel(ABC):
|
||||
|
||||
if not schema:
|
||||
return None
|
||||
|
||||
|
||||
# fill in the template
|
||||
new_parameter_rules = []
|
||||
for parameter_rule in schema.parameter_rules:
|
||||
@ -271,10 +275,20 @@ class AIModel(ABC):
|
||||
parameter_rule.help = I18nObject(
|
||||
en_US=default_parameter_rule['help']['en_US'],
|
||||
)
|
||||
if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
|
||||
parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
|
||||
if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
|
||||
parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
|
||||
if (
|
||||
parameter_rule.help
|
||||
and not parameter_rule.help.en_US
|
||||
and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
|
||||
):
|
||||
parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
|
||||
if (
|
||||
parameter_rule.help
|
||||
and not parameter_rule.help.zh_Hans
|
||||
and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
|
||||
):
|
||||
parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
|
||||
"zh_Hans", default_parameter_rule["help"]["en_US"]
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
@ -284,7 +298,7 @@ class AIModel(ABC):
|
||||
|
||||
return schema
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema
|
||||
|
||||
@ -304,7 +318,7 @@ class AIModel(ABC):
|
||||
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
|
||||
|
||||
if not default_parameter_rule:
|
||||
raise Exception(f'Invalid model parameter rule name {name}')
|
||||
raise Exception(f"Invalid model parameter rule name {name}")
|
||||
|
||||
return default_parameter_rule
|
||||
|
||||
@ -318,4 +332,4 @@ class AIModel(ABC):
|
||||
:param text: plain text of prompt. You need to convert the original message to plain text
|
||||
:return: number of tokens
|
||||
"""
|
||||
return GPT2Tokenizer.get_num_tokens(text)
|
||||
return GPT2Tokenizer.get_num_tokens(text)
|
||||
|
@ -3,7 +3,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
@ -43,7 +43,7 @@ class LargeLanguageModel(AIModel):
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
|
||||
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -129,7 +129,7 @@ class LargeLanguageModel(AIModel):
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
)
|
||||
else:
|
||||
elif isinstance(result, LLMResult):
|
||||
self._trigger_after_invoke_callbacks(
|
||||
model=model,
|
||||
result=result,
|
||||
@ -148,7 +148,7 @@ class LargeLanguageModel(AIModel):
|
||||
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
|
||||
callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Code block mode wrapper, ensure the response is a code block with output markdown quote
|
||||
|
||||
@ -196,7 +196,7 @@ if you are not sure about the structure.
|
||||
# override the system message
|
||||
prompt_messages[0] = SystemPromptMessage(
|
||||
content=block_prompts
|
||||
.replace("{{instructions}}", prompt_messages[0].content)
|
||||
.replace("{{instructions}}", str(prompt_messages[0].content))
|
||||
)
|
||||
else:
|
||||
# insert the system message
|
||||
@ -274,8 +274,9 @@ if you are not sure about the structure.
|
||||
else:
|
||||
yield piece
|
||||
continue
|
||||
new_piece = ""
|
||||
new_piece: str = ""
|
||||
for char in piece:
|
||||
char = str(char)
|
||||
if state == "normal":
|
||||
if char == "`":
|
||||
state = "in_backticks"
|
||||
@ -340,7 +341,7 @@ if you are not sure about the structure.
|
||||
if state == "done":
|
||||
continue
|
||||
|
||||
new_piece = ""
|
||||
new_piece: str = ""
|
||||
for char in piece:
|
||||
if state == "search_start":
|
||||
if char == "`":
|
||||
@ -365,7 +366,7 @@ if you are not sure about the structure.
|
||||
# If backticks were counted but we're still collecting content, it was a false start
|
||||
new_piece += "`" * backtick_count
|
||||
backtick_count = 0
|
||||
new_piece += char
|
||||
new_piece += str(char)
|
||||
|
||||
elif state == "done":
|
||||
break
|
||||
@ -388,13 +389,14 @@ if you are not sure about the structure.
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
|
||||
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator:
|
||||
"""
|
||||
Invoke result generator
|
||||
|
||||
:param result: result generator
|
||||
:return: result generator
|
||||
"""
|
||||
callbacks = callbacks or []
|
||||
prompt_message = AssistantPromptMessage(
|
||||
content=""
|
||||
)
|
||||
@ -489,6 +491,7 @@ if you are not sure about the structure.
|
||||
|
||||
def _llm_result_to_stream(self, result: LLMResult) -> Generator:
|
||||
"""
|
||||
from typing_extensions import deprecated
|
||||
Transform llm result to stream
|
||||
|
||||
:param result: llm result
|
||||
@ -531,7 +534,7 @@ if you are not sure about the structure.
|
||||
|
||||
return []
|
||||
|
||||
def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
|
||||
def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
|
||||
"""
|
||||
Get model mode
|
||||
|
||||
@ -595,7 +598,7 @@ if you are not sure about the structure.
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
||||
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
|
||||
"""
|
||||
Trigger before invoke callbacks
|
||||
|
||||
@ -633,7 +636,7 @@ if you are not sure about the structure.
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
||||
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
|
||||
"""
|
||||
Trigger new chunk callbacks
|
||||
|
||||
@ -672,7 +675,7 @@ if you are not sure about the structure.
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
||||
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
|
||||
"""
|
||||
Trigger after invoke callbacks
|
||||
|
||||
@ -712,7 +715,7 @@ if you are not sure about the structure.
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
||||
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
|
||||
"""
|
||||
Trigger invoke error callbacks
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
@ -9,7 +10,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
class ModelProvider(ABC):
|
||||
provider_schema: ProviderEntity = None
|
||||
provider_schema: Optional[ProviderEntity] = None
|
||||
model_instance_map: dict[str, AIModel] = {}
|
||||
|
||||
@abstractmethod
|
||||
@ -28,23 +29,23 @@ class ModelProvider(ABC):
|
||||
def get_provider_schema(self) -> ProviderEntity:
|
||||
"""
|
||||
Get provider schema
|
||||
|
||||
|
||||
:return: provider schema
|
||||
"""
|
||||
if self.provider_schema:
|
||||
return self.provider_schema
|
||||
|
||||
|
||||
# get dirname of the current path
|
||||
provider_name = self.__class__.__module__.split('.')[-1]
|
||||
|
||||
# get the path of the model_provider classes
|
||||
base_path = os.path.abspath(__file__)
|
||||
current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
|
||||
|
||||
|
||||
# read provider schema from yaml file
|
||||
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
|
||||
yaml_data = load_yaml_file(yaml_path, ignore_error=True)
|
||||
|
||||
|
||||
try:
|
||||
# yaml_data to entity
|
||||
provider_schema = ProviderEntity(**yaml_data)
|
||||
@ -53,7 +54,7 @@ class ModelProvider(ABC):
|
||||
|
||||
# cache schema
|
||||
self.provider_schema = provider_schema
|
||||
|
||||
|
||||
return provider_schema
|
||||
|
||||
def models(self, model_type: ModelType) -> list[AIModelEntity]:
|
||||
@ -84,7 +85,7 @@ class ModelProvider(ABC):
|
||||
:return:
|
||||
"""
|
||||
# get dirname of the current path
|
||||
provider_name = self.__class__.__module__.split('.')[-1]
|
||||
provider_name = self.__class__.__module__.split(".")[-1]
|
||||
|
||||
if f"{provider_name}.{model_type.value}" in self.model_instance_map:
|
||||
return self.model_instance_map[f"{provider_name}.{model_type.value}"]
|
||||
@ -101,11 +102,17 @@ class ModelProvider(ABC):
|
||||
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
|
||||
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
|
||||
mod = import_module_from_source(
|
||||
f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
|
||||
model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
|
||||
get_subclasses_from_module(mod, AIModel)), None)
|
||||
module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
|
||||
)
|
||||
model_class = next(
|
||||
filter(
|
||||
lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
|
||||
get_subclasses_from_module(mod, AIModel),
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not model_class:
|
||||
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
|
||||
raise Exception(f"Missing AIModel Class for model type {model_type} in {model_type_py_path}")
|
||||
|
||||
model_instance_map = model_class()
|
||||
self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map
|
||||
|
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
@ -16,20 +17,21 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelProviderExtension(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
provider_instance: ModelProvider
|
||||
name: str
|
||||
position: Optional[int] = None
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ModelProviderFactory:
|
||||
model_provider_extensions: dict[str, ModelProviderExtension] = None
|
||||
model_provider_extensions: Optional[dict[str, ModelProviderExtension]] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
# for cache in memory
|
||||
self.get_providers()
|
||||
|
||||
def get_providers(self) -> list[ProviderEntity]:
|
||||
def get_providers(self) -> Sequence[ProviderEntity]:
|
||||
"""
|
||||
Get all providers
|
||||
:return: list of providers
|
||||
@ -39,7 +41,7 @@ class ModelProviderFactory:
|
||||
|
||||
# traverse all model_provider_extensions
|
||||
providers = []
|
||||
for name, model_provider_extension in model_provider_extensions.items():
|
||||
for model_provider_extension in model_provider_extensions.values():
|
||||
# get model_provider instance
|
||||
model_provider_instance = model_provider_extension.provider_instance
|
||||
|
||||
@ -57,7 +59,7 @@ class ModelProviderFactory:
|
||||
# return providers
|
||||
return providers
|
||||
|
||||
def provider_credentials_validate(self, provider: str, credentials: dict) -> dict:
|
||||
def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict:
|
||||
"""
|
||||
Validate provider credentials
|
||||
|
||||
@ -74,6 +76,9 @@ class ModelProviderFactory:
|
||||
# get provider_credential_schema and validate credentials according to the rules
|
||||
provider_credential_schema = provider_schema.provider_credential_schema
|
||||
|
||||
if not provider_credential_schema:
|
||||
raise ValueError(f"Provider {provider} does not have provider_credential_schema")
|
||||
|
||||
# validate provider credential schema
|
||||
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
|
||||
filtered_credentials = validator.validate_and_filter(credentials)
|
||||
@ -83,8 +88,9 @@ class ModelProviderFactory:
|
||||
|
||||
return filtered_credentials
|
||||
|
||||
def model_credentials_validate(self, provider: str, model_type: ModelType,
|
||||
model: str, credentials: dict) -> dict:
|
||||
def model_credentials_validate(
|
||||
self, *, provider: str, model_type: ModelType, model: str, credentials: dict
|
||||
) -> dict:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
@ -103,6 +109,9 @@ class ModelProviderFactory:
|
||||
# get model_credential_schema and validate credentials according to the rules
|
||||
model_credential_schema = provider_schema.model_credential_schema
|
||||
|
||||
if not model_credential_schema:
|
||||
raise ValueError(f"Provider {provider} does not have model_credential_schema")
|
||||
|
||||
# validate model credential schema
|
||||
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
|
||||
filtered_credentials = validator.validate_and_filter(credentials)
|
||||
@ -115,11 +124,13 @@ class ModelProviderFactory:
|
||||
|
||||
return filtered_credentials
|
||||
|
||||
def get_models(self,
|
||||
provider: Optional[str] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
provider_configs: Optional[list[ProviderConfig]] = None) \
|
||||
-> list[SimpleProviderEntity]:
|
||||
def get_models(
|
||||
self,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
provider_configs: Optional[list[ProviderConfig]] = None,
|
||||
) -> list[SimpleProviderEntity]:
|
||||
"""
|
||||
Get all models for given model type
|
||||
|
||||
@ -128,6 +139,8 @@ class ModelProviderFactory:
|
||||
:param provider_configs: list of provider configs
|
||||
:return: list of models
|
||||
"""
|
||||
provider_configs = provider_configs or []
|
||||
|
||||
# scan all providers
|
||||
model_provider_extensions = self._get_model_provider_map()
|
||||
|
||||
@ -184,7 +197,7 @@ class ModelProviderFactory:
|
||||
# get the provider extension
|
||||
model_provider_extension = model_provider_extensions.get(provider)
|
||||
if not model_provider_extension:
|
||||
raise Exception(f'Invalid provider: {provider}')
|
||||
raise Exception(f"Invalid provider: {provider}")
|
||||
|
||||
# get the provider instance
|
||||
model_provider_instance = model_provider_extension.provider_instance
|
||||
@ -192,10 +205,22 @@ class ModelProviderFactory:
|
||||
return model_provider_instance
|
||||
|
||||
def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
|
||||
"""
|
||||
Retrieves the model provider map.
|
||||
|
||||
This method retrieves the model provider map, which is a dictionary containing the model provider names as keys
|
||||
and instances of `ModelProviderExtension` as values. The model provider map is used to store information about
|
||||
available model providers.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the model provider map.
|
||||
|
||||
Raises:
|
||||
None.
|
||||
"""
|
||||
if self.model_provider_extensions:
|
||||
return self.model_provider_extensions
|
||||
|
||||
|
||||
# get the path of current classes
|
||||
current_path = os.path.abspath(__file__)
|
||||
model_providers_path = os.path.dirname(current_path)
|
||||
@ -204,8 +229,8 @@ class ModelProviderFactory:
|
||||
model_provider_dir_paths = [
|
||||
os.path.join(model_providers_path, model_provider_dir)
|
||||
for model_provider_dir in os.listdir(model_providers_path)
|
||||
if not model_provider_dir.startswith('__')
|
||||
and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
|
||||
if not model_provider_dir.startswith("__")
|
||||
and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
|
||||
]
|
||||
|
||||
# get _position.yaml file path
|
||||
@ -219,30 +244,33 @@ class ModelProviderFactory:
|
||||
|
||||
file_names = os.listdir(model_provider_dir_path)
|
||||
|
||||
if (model_provider_name + '.py') not in file_names:
|
||||
if (model_provider_name + ".py") not in file_names:
|
||||
logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.")
|
||||
continue
|
||||
|
||||
# Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
|
||||
py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
|
||||
py_path = os.path.join(model_provider_dir_path, model_provider_name + ".py")
|
||||
model_provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
|
||||
module_name=f"core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}",
|
||||
script_path=py_path,
|
||||
parent_type=ModelProvider)
|
||||
parent_type=ModelProvider,
|
||||
)
|
||||
|
||||
if not model_provider_class:
|
||||
logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
|
||||
continue
|
||||
|
||||
if f'{model_provider_name}.yaml' not in file_names:
|
||||
if f"{model_provider_name}.yaml" not in file_names:
|
||||
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
|
||||
continue
|
||||
|
||||
model_providers.append(ModelProviderExtension(
|
||||
name=model_provider_name,
|
||||
provider_instance=model_provider_class(),
|
||||
position=position_map.get(model_provider_name)
|
||||
))
|
||||
model_providers.append(
|
||||
ModelProviderExtension(
|
||||
name=model_provider_name,
|
||||
provider_instance=model_provider_class(),
|
||||
position=position_map.get(model_provider_name),
|
||||
)
|
||||
)
|
||||
|
||||
sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
import openai
|
||||
from httpx import Timeout
|
||||
|
||||
@ -12,7 +14,7 @@ from core.model_runtime.errors.invoke import (
|
||||
|
||||
|
||||
class _CommonOpenAI:
|
||||
def _to_credential_kwargs(self, credentials: dict) -> dict:
|
||||
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
|
||||
"""
|
||||
Transform credentials to kwargs for model instance
|
||||
|
||||
@ -25,9 +27,9 @@ class _CommonOpenAI:
|
||||
"max_retries": 1,
|
||||
}
|
||||
|
||||
if credentials.get('openai_api_base'):
|
||||
credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/')
|
||||
credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1'
|
||||
if credentials.get("openai_api_base"):
|
||||
openai_api_base = credentials["openai_api_base"].rstrip("/")
|
||||
credentials_kwargs["base_url"] = openai_api_base + "/v1"
|
||||
|
||||
if 'openai_organization' in credentials:
|
||||
credentials_kwargs['organization'] = credentials['openai_organization']
|
||||
@ -45,24 +47,14 @@ class _CommonOpenAI:
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
openai.APIConnectionError,
|
||||
openai.APITimeoutError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
openai.InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
openai.RateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
openai.AuthenticationError,
|
||||
openai.PermissionDeniedError
|
||||
],
|
||||
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
|
||||
InvokeServerUnavailableError: [openai.InternalServerError],
|
||||
InvokeRateLimitError: [openai.RateLimitError],
|
||||
InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
|
||||
InvokeBadRequestError: [
|
||||
openai.BadRequestError,
|
||||
openai.NotFoundError,
|
||||
openai.UnprocessableEntityError,
|
||||
openai.APIError
|
||||
]
|
||||
openai.APIError,
|
||||
],
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
@ -9,7 +10,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class OpenAIProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
def validate_provider_credentials(self, credentials: Mapping) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
Loading…
x
Reference in New Issue
Block a user