mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 03:45:55 +08:00
fix(core): Fix incorrect type hints. (#5427)
This commit is contained in:
parent
e4259a8f13
commit
23fa3dedc4
@ -1,5 +1,5 @@
|
|||||||
import enum
|
import enum
|
||||||
import importlib
|
import importlib.util
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@ -74,6 +74,8 @@ class Extensible:
|
|||||||
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
|
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
|
||||||
py_path = os.path.join(subdir_path, extension_name + '.py')
|
py_path = os.path.join(subdir_path, extension_name + '.py')
|
||||||
spec = importlib.util.spec_from_file_location(extension_name, py_path)
|
spec = importlib.util.spec_from_file_location(extension_name, py_path)
|
||||||
|
if not spec or not spec.loader:
|
||||||
|
raise Exception(f"Failed to load module {extension_name} from {py_path}")
|
||||||
mod = importlib.util.module_from_spec(spec)
|
mod = importlib.util.module_from_spec(spec)
|
||||||
spec.loader.exec_module(mod)
|
spec.loader.exec_module(mod)
|
||||||
|
|
||||||
@ -108,6 +110,6 @@ class Extensible:
|
|||||||
position=position
|
position=position
|
||||||
))
|
))
|
||||||
|
|
||||||
sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)
|
sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name)
|
||||||
|
|
||||||
return sorted_extensions
|
return sorted_extensions
|
||||||
|
@ -5,11 +5,7 @@ from types import ModuleType
|
|||||||
from typing import AnyStr
|
from typing import AnyStr
|
||||||
|
|
||||||
|
|
||||||
def import_module_from_source(
|
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
|
||||||
module_name: str,
|
|
||||||
py_file_path: AnyStr,
|
|
||||||
use_lazy_loader: bool = False
|
|
||||||
) -> ModuleType:
|
|
||||||
"""
|
"""
|
||||||
Importing a module from the source file directly
|
Importing a module from the source file directly
|
||||||
"""
|
"""
|
||||||
@ -17,9 +13,13 @@ def import_module_from_source(
|
|||||||
existed_spec = importlib.util.find_spec(module_name)
|
existed_spec = importlib.util.find_spec(module_name)
|
||||||
if existed_spec:
|
if existed_spec:
|
||||||
spec = existed_spec
|
spec = existed_spec
|
||||||
|
if not spec.loader:
|
||||||
|
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
|
||||||
else:
|
else:
|
||||||
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
|
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
|
||||||
spec = importlib.util.spec_from_file_location(module_name, py_file_path)
|
spec = importlib.util.spec_from_file_location(module_name, py_file_path)
|
||||||
|
if not spec or not spec.loader:
|
||||||
|
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
|
||||||
if use_lazy_loader:
|
if use_lazy_loader:
|
||||||
# Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
|
# Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
|
||||||
spec.loader = importlib.util.LazyLoader(spec.loader)
|
spec.loader = importlib.util.LazyLoader(spec.loader)
|
||||||
@ -29,7 +29,7 @@ def import_module_from_source(
|
|||||||
spec.loader.exec_module(module)
|
spec.loader.exec_module(module)
|
||||||
return module
|
return module
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}')
|
logging.exception(f"Failed to load module {module_name} from {py_file_path}: {str(e)}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
@ -43,15 +43,14 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
|
|||||||
|
|
||||||
|
|
||||||
def load_single_subclass_from_source(
|
def load_single_subclass_from_source(
|
||||||
module_name: str,
|
*, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
|
||||||
script_path: AnyStr,
|
|
||||||
parent_type: type,
|
|
||||||
use_lazy_loader: bool = False,
|
|
||||||
) -> type:
|
) -> type:
|
||||||
"""
|
"""
|
||||||
Load a single subclass from the source
|
Load a single subclass from the source
|
||||||
"""
|
"""
|
||||||
module = import_module_from_source(module_name, script_path, use_lazy_loader)
|
module = import_module_from_source(
|
||||||
|
module_name=module_name, py_file_path=script_path, use_lazy_loader=use_lazy_loader
|
||||||
|
)
|
||||||
subclasses = get_subclasses_from_module(module, parent_type)
|
subclasses = get_subclasses_from_module(module, parent_type)
|
||||||
match len(subclasses):
|
match len(subclasses):
|
||||||
case 1:
|
case 1:
|
||||||
|
@ -1,15 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any, AnyStr
|
from typing import Any
|
||||||
|
|
||||||
from core.tools.utils.yaml_utils import load_yaml_file
|
from core.tools.utils.yaml_utils import load_yaml_file
|
||||||
|
|
||||||
|
|
||||||
def get_position_map(
|
def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]:
|
||||||
folder_path: AnyStr,
|
|
||||||
file_name: str = '_position.yaml',
|
|
||||||
) -> dict[str, int]:
|
|
||||||
"""
|
"""
|
||||||
Get the mapping from name to index from a YAML file
|
Get the mapping from name to index from a YAML file
|
||||||
:param folder_path:
|
:param folder_path:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections.abc import Generator
|
from collections.abc import Callable, Generator
|
||||||
from typing import IO, Optional, Union, cast
|
from typing import IO, Optional, Union, cast
|
||||||
|
|
||||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||||
@ -102,7 +102,7 @@ class ModelInstance:
|
|||||||
|
|
||||||
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||||
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
|
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
|
||||||
-> Union[LLMResult, Generator]:
|
-> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
@ -291,7 +291,7 @@ class ModelInstance:
|
|||||||
streaming=streaming
|
streaming=streaming
|
||||||
)
|
)
|
||||||
|
|
||||||
def _round_robin_invoke(self, function: callable, *args, **kwargs):
|
def _round_robin_invoke(self, function: Callable, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Round-robin invoke
|
Round-robin invoke
|
||||||
:param function: function to invoke
|
:param function: function to invoke
|
||||||
@ -437,6 +437,7 @@ class LBModelManager:
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
current_index = redis_client.incr(cache_key)
|
current_index = redis_client.incr(cache_key)
|
||||||
|
current_index = cast(int, current_index)
|
||||||
if current_index >= 10000000:
|
if current_index >= 10000000:
|
||||||
current_index = 1
|
current_index = 1
|
||||||
redis_client.set(cache_key, current_index)
|
redis_client.set(cache_key, current_index)
|
||||||
@ -499,7 +500,10 @@ class LBModelManager:
|
|||||||
config.id
|
config.id
|
||||||
)
|
)
|
||||||
|
|
||||||
return redis_client.exists(cooldown_cache_key)
|
|
||||||
|
res = redis_client.exists(cooldown_cache_key)
|
||||||
|
res = cast(bool, res)
|
||||||
|
return res
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
|
def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
|
||||||
@ -528,4 +532,5 @@ class LBModelManager:
|
|||||||
if ttl == -2:
|
if ttl == -2:
|
||||||
return False, 0
|
return False, 0
|
||||||
|
|
||||||
|
ttl = cast(int, ttl)
|
||||||
return True, ttl
|
return True, ttl
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel
|
from core.model_runtime.entities.model_entities import ModelType, ProviderModel
|
||||||
|
|
||||||
|
|
||||||
class ConfigurateMethod(Enum):
|
class ConfigurateMethod(Enum):
|
||||||
@ -93,8 +94,8 @@ class SimpleProviderEntity(BaseModel):
|
|||||||
label: I18nObject
|
label: I18nObject
|
||||||
icon_small: Optional[I18nObject] = None
|
icon_small: Optional[I18nObject] = None
|
||||||
icon_large: Optional[I18nObject] = None
|
icon_large: Optional[I18nObject] = None
|
||||||
supported_model_types: list[ModelType]
|
supported_model_types: Sequence[ModelType]
|
||||||
models: list[AIModelEntity] = []
|
models: list[ProviderModel] = []
|
||||||
|
|
||||||
|
|
||||||
class ProviderHelpEntity(BaseModel):
|
class ProviderHelpEntity(BaseModel):
|
||||||
@ -116,7 +117,7 @@ class ProviderEntity(BaseModel):
|
|||||||
icon_large: Optional[I18nObject] = None
|
icon_large: Optional[I18nObject] = None
|
||||||
background: Optional[str] = None
|
background: Optional[str] = None
|
||||||
help: Optional[ProviderHelpEntity] = None
|
help: Optional[ProviderHelpEntity] = None
|
||||||
supported_model_types: list[ModelType]
|
supported_model_types: Sequence[ModelType]
|
||||||
configurate_methods: list[ConfigurateMethod]
|
configurate_methods: list[ConfigurateMethod]
|
||||||
models: list[ProviderModel] = []
|
models: list[ProviderModel] = []
|
||||||
provider_credential_schema: Optional[ProviderCredentialSchema] = None
|
provider_credential_schema: Optional[ProviderCredentialSchema] = None
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import decimal
|
import decimal
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Mapping
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
@ -26,15 +27,16 @@ class AIModel(ABC):
|
|||||||
"""
|
"""
|
||||||
Base class for all models.
|
Base class for all models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_type: ModelType
|
model_type: ModelType
|
||||||
model_schemas: list[AIModelEntity] = None
|
model_schemas: Optional[list[AIModelEntity]] = None
|
||||||
started_at: float = 0
|
started_at: float = 0
|
||||||
|
|
||||||
# pydantic configs
|
# pydantic configs
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
def validate_credentials(self, model: str, credentials: Mapping) -> None:
|
||||||
"""
|
"""
|
||||||
Validate model credentials
|
Validate model credentials
|
||||||
|
|
||||||
@ -90,8 +92,8 @@ class AIModel(ABC):
|
|||||||
|
|
||||||
# get price info from predefined model schema
|
# get price info from predefined model schema
|
||||||
price_config: Optional[PriceConfig] = None
|
price_config: Optional[PriceConfig] = None
|
||||||
if model_schema:
|
if model_schema and model_schema.pricing:
|
||||||
price_config: PriceConfig = model_schema.pricing
|
price_config = model_schema.pricing
|
||||||
|
|
||||||
# get unit price
|
# get unit price
|
||||||
unit_price = None
|
unit_price = None
|
||||||
@ -103,13 +105,15 @@ class AIModel(ABC):
|
|||||||
|
|
||||||
if unit_price is None:
|
if unit_price is None:
|
||||||
return PriceInfo(
|
return PriceInfo(
|
||||||
unit_price=decimal.Decimal('0.0'),
|
unit_price=decimal.Decimal("0.0"),
|
||||||
unit=decimal.Decimal('0.0'),
|
unit=decimal.Decimal("0.0"),
|
||||||
total_amount=decimal.Decimal('0.0'),
|
total_amount=decimal.Decimal("0.0"),
|
||||||
currency="USD",
|
currency="USD",
|
||||||
)
|
)
|
||||||
|
|
||||||
# calculate total amount
|
# calculate total amount
|
||||||
|
if not price_config:
|
||||||
|
raise ValueError(f"Price config not found for model {model}")
|
||||||
total_amount = tokens * unit_price * price_config.unit
|
total_amount = tokens * unit_price * price_config.unit
|
||||||
total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||||
|
|
||||||
@ -209,7 +213,7 @@ class AIModel(ABC):
|
|||||||
|
|
||||||
return model_schemas
|
return model_schemas
|
||||||
|
|
||||||
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
|
def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
|
||||||
"""
|
"""
|
||||||
Get model schema by model name and credentials
|
Get model schema by model name and credentials
|
||||||
|
|
||||||
@ -231,7 +235,7 @@ class AIModel(ABC):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
|
||||||
"""
|
"""
|
||||||
Get customizable model schema from credentials
|
Get customizable model schema from credentials
|
||||||
|
|
||||||
@ -240,8 +244,8 @@ class AIModel(ABC):
|
|||||||
:return: model schema
|
:return: model schema
|
||||||
"""
|
"""
|
||||||
return self._get_customizable_model_schema(model, credentials)
|
return self._get_customizable_model_schema(model, credentials)
|
||||||
|
|
||||||
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
|
||||||
"""
|
"""
|
||||||
Get customizable model schema and fill in the template
|
Get customizable model schema and fill in the template
|
||||||
"""
|
"""
|
||||||
@ -249,7 +253,7 @@ class AIModel(ABC):
|
|||||||
|
|
||||||
if not schema:
|
if not schema:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# fill in the template
|
# fill in the template
|
||||||
new_parameter_rules = []
|
new_parameter_rules = []
|
||||||
for parameter_rule in schema.parameter_rules:
|
for parameter_rule in schema.parameter_rules:
|
||||||
@ -271,10 +275,20 @@ class AIModel(ABC):
|
|||||||
parameter_rule.help = I18nObject(
|
parameter_rule.help = I18nObject(
|
||||||
en_US=default_parameter_rule['help']['en_US'],
|
en_US=default_parameter_rule['help']['en_US'],
|
||||||
)
|
)
|
||||||
if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
|
if (
|
||||||
parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
|
parameter_rule.help
|
||||||
if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
|
and not parameter_rule.help.en_US
|
||||||
parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
|
and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
|
||||||
|
):
|
||||||
|
parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
|
||||||
|
if (
|
||||||
|
parameter_rule.help
|
||||||
|
and not parameter_rule.help.zh_Hans
|
||||||
|
and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
|
||||||
|
):
|
||||||
|
parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
|
||||||
|
"zh_Hans", default_parameter_rule["help"]["en_US"]
|
||||||
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -284,7 +298,7 @@ class AIModel(ABC):
|
|||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
|
||||||
"""
|
"""
|
||||||
Get customizable model schema
|
Get customizable model schema
|
||||||
|
|
||||||
@ -304,7 +318,7 @@ class AIModel(ABC):
|
|||||||
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
|
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
|
||||||
|
|
||||||
if not default_parameter_rule:
|
if not default_parameter_rule:
|
||||||
raise Exception(f'Invalid model parameter rule name {name}')
|
raise Exception(f"Invalid model parameter rule name {name}")
|
||||||
|
|
||||||
return default_parameter_rule
|
return default_parameter_rule
|
||||||
|
|
||||||
@ -318,4 +332,4 @@ class AIModel(ABC):
|
|||||||
:param text: plain text of prompt. You need to convert the original message to plain text
|
:param text: plain text of prompt. You need to convert the original message to plain text
|
||||||
:return: number of tokens
|
:return: number of tokens
|
||||||
"""
|
"""
|
||||||
return GPT2Tokenizer.get_num_tokens(text)
|
return GPT2Tokenizer.get_num_tokens(text)
|
||||||
|
@ -3,7 +3,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
@ -43,7 +43,7 @@ class LargeLanguageModel(AIModel):
|
|||||||
def invoke(self, model: str, credentials: dict,
|
def invoke(self, model: str, credentials: dict,
|
||||||
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||||
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
|
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
|
||||||
-> Union[LLMResult, Generator]:
|
-> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
@ -129,7 +129,7 @@ class LargeLanguageModel(AIModel):
|
|||||||
user=user,
|
user=user,
|
||||||
callbacks=callbacks
|
callbacks=callbacks
|
||||||
)
|
)
|
||||||
else:
|
elif isinstance(result, LLMResult):
|
||||||
self._trigger_after_invoke_callbacks(
|
self._trigger_after_invoke_callbacks(
|
||||||
model=model,
|
model=model,
|
||||||
result=result,
|
result=result,
|
||||||
@ -148,7 +148,7 @@ class LargeLanguageModel(AIModel):
|
|||||||
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
|
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
|
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
|
||||||
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
|
callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Code block mode wrapper, ensure the response is a code block with output markdown quote
|
Code block mode wrapper, ensure the response is a code block with output markdown quote
|
||||||
|
|
||||||
@ -196,7 +196,7 @@ if you are not sure about the structure.
|
|||||||
# override the system message
|
# override the system message
|
||||||
prompt_messages[0] = SystemPromptMessage(
|
prompt_messages[0] = SystemPromptMessage(
|
||||||
content=block_prompts
|
content=block_prompts
|
||||||
.replace("{{instructions}}", prompt_messages[0].content)
|
.replace("{{instructions}}", str(prompt_messages[0].content))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# insert the system message
|
# insert the system message
|
||||||
@ -274,8 +274,9 @@ if you are not sure about the structure.
|
|||||||
else:
|
else:
|
||||||
yield piece
|
yield piece
|
||||||
continue
|
continue
|
||||||
new_piece = ""
|
new_piece: str = ""
|
||||||
for char in piece:
|
for char in piece:
|
||||||
|
char = str(char)
|
||||||
if state == "normal":
|
if state == "normal":
|
||||||
if char == "`":
|
if char == "`":
|
||||||
state = "in_backticks"
|
state = "in_backticks"
|
||||||
@ -340,7 +341,7 @@ if you are not sure about the structure.
|
|||||||
if state == "done":
|
if state == "done":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
new_piece = ""
|
new_piece: str = ""
|
||||||
for char in piece:
|
for char in piece:
|
||||||
if state == "search_start":
|
if state == "search_start":
|
||||||
if char == "`":
|
if char == "`":
|
||||||
@ -365,7 +366,7 @@ if you are not sure about the structure.
|
|||||||
# If backticks were counted but we're still collecting content, it was a false start
|
# If backticks were counted but we're still collecting content, it was a false start
|
||||||
new_piece += "`" * backtick_count
|
new_piece += "`" * backtick_count
|
||||||
backtick_count = 0
|
backtick_count = 0
|
||||||
new_piece += char
|
new_piece += str(char)
|
||||||
|
|
||||||
elif state == "done":
|
elif state == "done":
|
||||||
break
|
break
|
||||||
@ -388,13 +389,14 @@ if you are not sure about the structure.
|
|||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None, stream: bool = True,
|
stop: Optional[list[str]] = None, stream: bool = True,
|
||||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
|
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator:
|
||||||
"""
|
"""
|
||||||
Invoke result generator
|
Invoke result generator
|
||||||
|
|
||||||
:param result: result generator
|
:param result: result generator
|
||||||
:return: result generator
|
:return: result generator
|
||||||
"""
|
"""
|
||||||
|
callbacks = callbacks or []
|
||||||
prompt_message = AssistantPromptMessage(
|
prompt_message = AssistantPromptMessage(
|
||||||
content=""
|
content=""
|
||||||
)
|
)
|
||||||
@ -489,6 +491,7 @@ if you are not sure about the structure.
|
|||||||
|
|
||||||
def _llm_result_to_stream(self, result: LLMResult) -> Generator:
|
def _llm_result_to_stream(self, result: LLMResult) -> Generator:
|
||||||
"""
|
"""
|
||||||
|
from typing_extensions import deprecated
|
||||||
Transform llm result to stream
|
Transform llm result to stream
|
||||||
|
|
||||||
:param result: llm result
|
:param result: llm result
|
||||||
@ -531,7 +534,7 @@ if you are not sure about the structure.
|
|||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
|
def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
|
||||||
"""
|
"""
|
||||||
Get model mode
|
Get model mode
|
||||||
|
|
||||||
@ -595,7 +598,7 @@ if you are not sure about the structure.
|
|||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None, stream: bool = True,
|
stop: Optional[list[str]] = None, stream: bool = True,
|
||||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Trigger before invoke callbacks
|
Trigger before invoke callbacks
|
||||||
|
|
||||||
@ -633,7 +636,7 @@ if you are not sure about the structure.
|
|||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None, stream: bool = True,
|
stop: Optional[list[str]] = None, stream: bool = True,
|
||||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Trigger new chunk callbacks
|
Trigger new chunk callbacks
|
||||||
|
|
||||||
@ -672,7 +675,7 @@ if you are not sure about the structure.
|
|||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None, stream: bool = True,
|
stop: Optional[list[str]] = None, stream: bool = True,
|
||||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Trigger after invoke callbacks
|
Trigger after invoke callbacks
|
||||||
|
|
||||||
@ -712,7 +715,7 @@ if you are not sure about the structure.
|
|||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None, stream: bool = True,
|
stop: Optional[list[str]] = None, stream: bool = True,
|
||||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Trigger invoke error callbacks
|
Trigger invoke error callbacks
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
|
from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||||
@ -9,7 +10,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
|
|||||||
|
|
||||||
|
|
||||||
class ModelProvider(ABC):
|
class ModelProvider(ABC):
|
||||||
provider_schema: ProviderEntity = None
|
provider_schema: Optional[ProviderEntity] = None
|
||||||
model_instance_map: dict[str, AIModel] = {}
|
model_instance_map: dict[str, AIModel] = {}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -28,23 +29,23 @@ class ModelProvider(ABC):
|
|||||||
def get_provider_schema(self) -> ProviderEntity:
|
def get_provider_schema(self) -> ProviderEntity:
|
||||||
"""
|
"""
|
||||||
Get provider schema
|
Get provider schema
|
||||||
|
|
||||||
:return: provider schema
|
:return: provider schema
|
||||||
"""
|
"""
|
||||||
if self.provider_schema:
|
if self.provider_schema:
|
||||||
return self.provider_schema
|
return self.provider_schema
|
||||||
|
|
||||||
# get dirname of the current path
|
# get dirname of the current path
|
||||||
provider_name = self.__class__.__module__.split('.')[-1]
|
provider_name = self.__class__.__module__.split('.')[-1]
|
||||||
|
|
||||||
# get the path of the model_provider classes
|
# get the path of the model_provider classes
|
||||||
base_path = os.path.abspath(__file__)
|
base_path = os.path.abspath(__file__)
|
||||||
current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
|
current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
|
||||||
|
|
||||||
# read provider schema from yaml file
|
# read provider schema from yaml file
|
||||||
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
|
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
|
||||||
yaml_data = load_yaml_file(yaml_path, ignore_error=True)
|
yaml_data = load_yaml_file(yaml_path, ignore_error=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# yaml_data to entity
|
# yaml_data to entity
|
||||||
provider_schema = ProviderEntity(**yaml_data)
|
provider_schema = ProviderEntity(**yaml_data)
|
||||||
@ -53,7 +54,7 @@ class ModelProvider(ABC):
|
|||||||
|
|
||||||
# cache schema
|
# cache schema
|
||||||
self.provider_schema = provider_schema
|
self.provider_schema = provider_schema
|
||||||
|
|
||||||
return provider_schema
|
return provider_schema
|
||||||
|
|
||||||
def models(self, model_type: ModelType) -> list[AIModelEntity]:
|
def models(self, model_type: ModelType) -> list[AIModelEntity]:
|
||||||
@ -84,7 +85,7 @@ class ModelProvider(ABC):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# get dirname of the current path
|
# get dirname of the current path
|
||||||
provider_name = self.__class__.__module__.split('.')[-1]
|
provider_name = self.__class__.__module__.split(".")[-1]
|
||||||
|
|
||||||
if f"{provider_name}.{model_type.value}" in self.model_instance_map:
|
if f"{provider_name}.{model_type.value}" in self.model_instance_map:
|
||||||
return self.model_instance_map[f"{provider_name}.{model_type.value}"]
|
return self.model_instance_map[f"{provider_name}.{model_type.value}"]
|
||||||
@ -101,11 +102,17 @@ class ModelProvider(ABC):
|
|||||||
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
|
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
|
||||||
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
|
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
|
||||||
mod = import_module_from_source(
|
mod = import_module_from_source(
|
||||||
f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
|
module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
|
||||||
model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
|
)
|
||||||
get_subclasses_from_module(mod, AIModel)), None)
|
model_class = next(
|
||||||
|
filter(
|
||||||
|
lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
|
||||||
|
get_subclasses_from_module(mod, AIModel),
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
if not model_class:
|
if not model_class:
|
||||||
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
|
raise Exception(f"Missing AIModel Class for model type {model_type} in {model_type_py_path}")
|
||||||
|
|
||||||
model_instance_map = model_class()
|
model_instance_map = model_class()
|
||||||
self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map
|
self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from collections.abc import Sequence
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
@ -16,20 +17,21 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ModelProviderExtension(BaseModel):
|
class ModelProviderExtension(BaseModel):
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
provider_instance: ModelProvider
|
provider_instance: ModelProvider
|
||||||
name: str
|
name: str
|
||||||
position: Optional[int] = None
|
position: Optional[int] = None
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderFactory:
|
class ModelProviderFactory:
|
||||||
model_provider_extensions: dict[str, ModelProviderExtension] = None
|
model_provider_extensions: Optional[dict[str, ModelProviderExtension]] = None
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
# for cache in memory
|
# for cache in memory
|
||||||
self.get_providers()
|
self.get_providers()
|
||||||
|
|
||||||
def get_providers(self) -> list[ProviderEntity]:
|
def get_providers(self) -> Sequence[ProviderEntity]:
|
||||||
"""
|
"""
|
||||||
Get all providers
|
Get all providers
|
||||||
:return: list of providers
|
:return: list of providers
|
||||||
@ -39,7 +41,7 @@ class ModelProviderFactory:
|
|||||||
|
|
||||||
# traverse all model_provider_extensions
|
# traverse all model_provider_extensions
|
||||||
providers = []
|
providers = []
|
||||||
for name, model_provider_extension in model_provider_extensions.items():
|
for model_provider_extension in model_provider_extensions.values():
|
||||||
# get model_provider instance
|
# get model_provider instance
|
||||||
model_provider_instance = model_provider_extension.provider_instance
|
model_provider_instance = model_provider_extension.provider_instance
|
||||||
|
|
||||||
@ -57,7 +59,7 @@ class ModelProviderFactory:
|
|||||||
# return providers
|
# return providers
|
||||||
return providers
|
return providers
|
||||||
|
|
||||||
def provider_credentials_validate(self, provider: str, credentials: dict) -> dict:
|
def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
Validate provider credentials
|
Validate provider credentials
|
||||||
|
|
||||||
@ -74,6 +76,9 @@ class ModelProviderFactory:
|
|||||||
# get provider_credential_schema and validate credentials according to the rules
|
# get provider_credential_schema and validate credentials according to the rules
|
||||||
provider_credential_schema = provider_schema.provider_credential_schema
|
provider_credential_schema = provider_schema.provider_credential_schema
|
||||||
|
|
||||||
|
if not provider_credential_schema:
|
||||||
|
raise ValueError(f"Provider {provider} does not have provider_credential_schema")
|
||||||
|
|
||||||
# validate provider credential schema
|
# validate provider credential schema
|
||||||
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
|
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
|
||||||
filtered_credentials = validator.validate_and_filter(credentials)
|
filtered_credentials = validator.validate_and_filter(credentials)
|
||||||
@ -83,8 +88,9 @@ class ModelProviderFactory:
|
|||||||
|
|
||||||
return filtered_credentials
|
return filtered_credentials
|
||||||
|
|
||||||
def model_credentials_validate(self, provider: str, model_type: ModelType,
|
def model_credentials_validate(
|
||||||
model: str, credentials: dict) -> dict:
|
self, *, provider: str, model_type: ModelType, model: str, credentials: dict
|
||||||
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Validate model credentials
|
Validate model credentials
|
||||||
|
|
||||||
@ -103,6 +109,9 @@ class ModelProviderFactory:
|
|||||||
# get model_credential_schema and validate credentials according to the rules
|
# get model_credential_schema and validate credentials according to the rules
|
||||||
model_credential_schema = provider_schema.model_credential_schema
|
model_credential_schema = provider_schema.model_credential_schema
|
||||||
|
|
||||||
|
if not model_credential_schema:
|
||||||
|
raise ValueError(f"Provider {provider} does not have model_credential_schema")
|
||||||
|
|
||||||
# validate model credential schema
|
# validate model credential schema
|
||||||
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
|
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
|
||||||
filtered_credentials = validator.validate_and_filter(credentials)
|
filtered_credentials = validator.validate_and_filter(credentials)
|
||||||
@ -115,11 +124,13 @@ class ModelProviderFactory:
|
|||||||
|
|
||||||
return filtered_credentials
|
return filtered_credentials
|
||||||
|
|
||||||
def get_models(self,
|
def get_models(
|
||||||
provider: Optional[str] = None,
|
self,
|
||||||
model_type: Optional[ModelType] = None,
|
*,
|
||||||
provider_configs: Optional[list[ProviderConfig]] = None) \
|
provider: Optional[str] = None,
|
||||||
-> list[SimpleProviderEntity]:
|
model_type: Optional[ModelType] = None,
|
||||||
|
provider_configs: Optional[list[ProviderConfig]] = None,
|
||||||
|
) -> list[SimpleProviderEntity]:
|
||||||
"""
|
"""
|
||||||
Get all models for given model type
|
Get all models for given model type
|
||||||
|
|
||||||
@ -128,6 +139,8 @@ class ModelProviderFactory:
|
|||||||
:param provider_configs: list of provider configs
|
:param provider_configs: list of provider configs
|
||||||
:return: list of models
|
:return: list of models
|
||||||
"""
|
"""
|
||||||
|
provider_configs = provider_configs or []
|
||||||
|
|
||||||
# scan all providers
|
# scan all providers
|
||||||
model_provider_extensions = self._get_model_provider_map()
|
model_provider_extensions = self._get_model_provider_map()
|
||||||
|
|
||||||
@ -184,7 +197,7 @@ class ModelProviderFactory:
|
|||||||
# get the provider extension
|
# get the provider extension
|
||||||
model_provider_extension = model_provider_extensions.get(provider)
|
model_provider_extension = model_provider_extensions.get(provider)
|
||||||
if not model_provider_extension:
|
if not model_provider_extension:
|
||||||
raise Exception(f'Invalid provider: {provider}')
|
raise Exception(f"Invalid provider: {provider}")
|
||||||
|
|
||||||
# get the provider instance
|
# get the provider instance
|
||||||
model_provider_instance = model_provider_extension.provider_instance
|
model_provider_instance = model_provider_extension.provider_instance
|
||||||
@ -192,10 +205,22 @@ class ModelProviderFactory:
|
|||||||
return model_provider_instance
|
return model_provider_instance
|
||||||
|
|
||||||
def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
|
def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
|
||||||
|
"""
|
||||||
|
Retrieves the model provider map.
|
||||||
|
|
||||||
|
This method retrieves the model provider map, which is a dictionary containing the model provider names as keys
|
||||||
|
and instances of `ModelProviderExtension` as values. The model provider map is used to store information about
|
||||||
|
available model providers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing the model provider map.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
None.
|
||||||
|
"""
|
||||||
if self.model_provider_extensions:
|
if self.model_provider_extensions:
|
||||||
return self.model_provider_extensions
|
return self.model_provider_extensions
|
||||||
|
|
||||||
|
|
||||||
# get the path of current classes
|
# get the path of current classes
|
||||||
current_path = os.path.abspath(__file__)
|
current_path = os.path.abspath(__file__)
|
||||||
model_providers_path = os.path.dirname(current_path)
|
model_providers_path = os.path.dirname(current_path)
|
||||||
@ -204,8 +229,8 @@ class ModelProviderFactory:
|
|||||||
model_provider_dir_paths = [
|
model_provider_dir_paths = [
|
||||||
os.path.join(model_providers_path, model_provider_dir)
|
os.path.join(model_providers_path, model_provider_dir)
|
||||||
for model_provider_dir in os.listdir(model_providers_path)
|
for model_provider_dir in os.listdir(model_providers_path)
|
||||||
if not model_provider_dir.startswith('__')
|
if not model_provider_dir.startswith("__")
|
||||||
and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
|
and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
|
||||||
]
|
]
|
||||||
|
|
||||||
# get _position.yaml file path
|
# get _position.yaml file path
|
||||||
@ -219,30 +244,33 @@ class ModelProviderFactory:
|
|||||||
|
|
||||||
file_names = os.listdir(model_provider_dir_path)
|
file_names = os.listdir(model_provider_dir_path)
|
||||||
|
|
||||||
if (model_provider_name + '.py') not in file_names:
|
if (model_provider_name + ".py") not in file_names:
|
||||||
logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.")
|
logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
|
# Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
|
||||||
py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
|
py_path = os.path.join(model_provider_dir_path, model_provider_name + ".py")
|
||||||
model_provider_class = load_single_subclass_from_source(
|
model_provider_class = load_single_subclass_from_source(
|
||||||
module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
|
module_name=f"core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}",
|
||||||
script_path=py_path,
|
script_path=py_path,
|
||||||
parent_type=ModelProvider)
|
parent_type=ModelProvider,
|
||||||
|
)
|
||||||
|
|
||||||
if not model_provider_class:
|
if not model_provider_class:
|
||||||
logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
|
logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if f'{model_provider_name}.yaml' not in file_names:
|
if f"{model_provider_name}.yaml" not in file_names:
|
||||||
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
|
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
model_providers.append(ModelProviderExtension(
|
model_providers.append(
|
||||||
name=model_provider_name,
|
ModelProviderExtension(
|
||||||
provider_instance=model_provider_class(),
|
name=model_provider_name,
|
||||||
position=position_map.get(model_provider_name)
|
provider_instance=model_provider_class(),
|
||||||
))
|
position=position_map.get(model_provider_name),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
|
sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from httpx import Timeout
|
from httpx import Timeout
|
||||||
|
|
||||||
@ -12,7 +14,7 @@ from core.model_runtime.errors.invoke import (
|
|||||||
|
|
||||||
|
|
||||||
class _CommonOpenAI:
|
class _CommonOpenAI:
|
||||||
def _to_credential_kwargs(self, credentials: dict) -> dict:
|
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
|
||||||
"""
|
"""
|
||||||
Transform credentials to kwargs for model instance
|
Transform credentials to kwargs for model instance
|
||||||
|
|
||||||
@ -25,9 +27,9 @@ class _CommonOpenAI:
|
|||||||
"max_retries": 1,
|
"max_retries": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
if credentials.get('openai_api_base'):
|
if credentials.get("openai_api_base"):
|
||||||
credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/')
|
openai_api_base = credentials["openai_api_base"].rstrip("/")
|
||||||
credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1'
|
credentials_kwargs["base_url"] = openai_api_base + "/v1"
|
||||||
|
|
||||||
if 'openai_organization' in credentials:
|
if 'openai_organization' in credentials:
|
||||||
credentials_kwargs['organization'] = credentials['openai_organization']
|
credentials_kwargs['organization'] = credentials['openai_organization']
|
||||||
@ -45,24 +47,14 @@ class _CommonOpenAI:
|
|||||||
:return: Invoke error mapping
|
:return: Invoke error mapping
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
InvokeConnectionError: [
|
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
|
||||||
openai.APIConnectionError,
|
InvokeServerUnavailableError: [openai.InternalServerError],
|
||||||
openai.APITimeoutError
|
InvokeRateLimitError: [openai.RateLimitError],
|
||||||
],
|
InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
|
||||||
InvokeServerUnavailableError: [
|
|
||||||
openai.InternalServerError
|
|
||||||
],
|
|
||||||
InvokeRateLimitError: [
|
|
||||||
openai.RateLimitError
|
|
||||||
],
|
|
||||||
InvokeAuthorizationError: [
|
|
||||||
openai.AuthenticationError,
|
|
||||||
openai.PermissionDeniedError
|
|
||||||
],
|
|
||||||
InvokeBadRequestError: [
|
InvokeBadRequestError: [
|
||||||
openai.BadRequestError,
|
openai.BadRequestError,
|
||||||
openai.NotFoundError,
|
openai.NotFoundError,
|
||||||
openai.UnprocessableEntityError,
|
openai.UnprocessableEntityError,
|
||||||
openai.APIError
|
openai.APIError,
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Mapping
|
||||||
|
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
@ -9,7 +10,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class OpenAIProvider(ModelProvider):
|
class OpenAIProvider(ModelProvider):
|
||||||
|
|
||||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
def validate_provider_credentials(self, credentials: Mapping) -> None:
|
||||||
"""
|
"""
|
||||||
Validate provider credentials
|
Validate provider credentials
|
||||||
if validate failed, raise exception
|
if validate failed, raise exception
|
||||||
|
Loading…
x
Reference in New Issue
Block a user