mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 14:48:57 +08:00
Azure openai init (#1929)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
b8592ad412
commit
5b24d7129e
@ -1,7 +1,7 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Optional, List, Dict, Tuple, Iterator
|
from typing import Optional, List, Dict, Tuple, Iterator
|
||||||
|
|
||||||
@ -11,8 +11,9 @@ from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, S
|
|||||||
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
|
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType, FetchFrom
|
||||||
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
|
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
|
||||||
|
ConfigurateMethod
|
||||||
from core.model_runtime.model_providers import model_provider_factory
|
from core.model_runtime.model_providers import model_provider_factory
|
||||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||||
@ -22,6 +23,8 @@ from models.provider import ProviderType, Provider, ProviderModel, TenantPreferr
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
original_provider_configurate_methods = {}
|
||||||
|
|
||||||
|
|
||||||
class ProviderConfiguration(BaseModel):
|
class ProviderConfiguration(BaseModel):
|
||||||
"""
|
"""
|
||||||
@ -34,6 +37,20 @@ class ProviderConfiguration(BaseModel):
|
|||||||
system_configuration: SystemConfiguration
|
system_configuration: SystemConfiguration
|
||||||
custom_configuration: CustomConfiguration
|
custom_configuration: CustomConfiguration
|
||||||
|
|
||||||
|
def __init__(self, **data):
    """
    Initialize the configuration and record the provider's original configurate methods.

    The first time a provider name is seen, a copy of its ``configurate_methods``
    is stored in the module-level ``original_provider_configurate_methods`` registry.
    When the recorded methods are exactly ``[ConfigurateMethod.CUSTOMIZABLE_MODEL]``
    and at least one system quota configuration has a non-empty ``restrict_models``
    list, ``ConfigurateMethod.PREDEFINED_MODEL`` is appended to
    ``self.provider.configurate_methods`` unless it is already present.

    :param data: field values forwarded to the parent initializer
    """
    super().__init__(**data)

    provider_name = self.provider.provider
    current_methods = self.provider.configurate_methods

    # Snapshot the methods once per provider name, before they may be extended below.
    if provider_name not in original_provider_configurate_methods:
        original_provider_configurate_methods[provider_name] = list(current_methods)

    # Only providers that were originally customizable-model-only are adjusted.
    if original_provider_configurate_methods[provider_name] != [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
        return

    has_restricted_models = any(
        len(quota.restrict_models) > 0
        for quota in self.system_configuration.quota_configurations
    )
    if has_restricted_models and ConfigurateMethod.PREDEFINED_MODEL not in current_methods:
        current_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
|
||||||
|
|
||||||
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
|
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
|
||||||
"""
|
"""
|
||||||
Get current credentials.
|
Get current credentials.
|
||||||
@ -123,7 +140,8 @@ class ProviderConfiguration(BaseModel):
|
|||||||
|
|
||||||
if provider_record:
|
if provider_record:
|
||||||
try:
|
try:
|
||||||
original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
|
original_credentials = json.loads(
|
||||||
|
provider_record.encrypted_config) if provider_record.encrypted_config else {}
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
original_credentials = {}
|
original_credentials = {}
|
||||||
|
|
||||||
@ -265,7 +283,8 @@ class ProviderConfiguration(BaseModel):
|
|||||||
|
|
||||||
if provider_model_record:
|
if provider_model_record:
|
||||||
try:
|
try:
|
||||||
original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
|
original_credentials = json.loads(
|
||||||
|
provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
original_credentials = {}
|
original_credentials = {}
|
||||||
|
|
||||||
@ -534,21 +553,70 @@ class ProviderConfiguration(BaseModel):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.provider.provider not in original_provider_configurate_methods:
|
||||||
|
original_provider_configurate_methods[self.provider.provider] = []
|
||||||
|
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
|
||||||
|
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
||||||
|
|
||||||
|
should_use_custom_model = False
|
||||||
|
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||||
|
should_use_custom_model = True
|
||||||
|
|
||||||
for quota_configuration in self.system_configuration.quota_configurations:
|
for quota_configuration in self.system_configuration.quota_configurations:
|
||||||
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
|
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
restrict_llms = quota_configuration.restrict_llms
|
restrict_models = quota_configuration.restrict_models
|
||||||
if not restrict_llms:
|
if len(restrict_models) == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if should_use_custom_model:
|
||||||
|
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||||
|
# only customizable model
|
||||||
|
for restrict_model in restrict_models:
|
||||||
|
copy_credentials = self.system_configuration.credentials.copy()
|
||||||
|
if restrict_model.base_model_name:
|
||||||
|
copy_credentials['base_model_name'] = restrict_model.base_model_name
|
||||||
|
|
||||||
|
try:
|
||||||
|
custom_model_schema = (
|
||||||
|
provider_instance.get_model_instance(restrict_model.model_type)
|
||||||
|
.get_customizable_model_schema_from_credentials(
|
||||||
|
restrict_model.model,
|
||||||
|
copy_credentials
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as ex:
|
||||||
|
logger.warning(f'get custom model schema failed, {ex}')
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not custom_model_schema:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if custom_model_schema.model_type not in model_types:
|
||||||
|
continue
|
||||||
|
|
||||||
|
provider_models.append(
|
||||||
|
ModelWithProviderEntity(
|
||||||
|
model=custom_model_schema.model,
|
||||||
|
label=custom_model_schema.label,
|
||||||
|
model_type=custom_model_schema.model_type,
|
||||||
|
features=custom_model_schema.features,
|
||||||
|
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||||
|
model_properties=custom_model_schema.model_properties,
|
||||||
|
deprecated=custom_model_schema.deprecated,
|
||||||
|
provider=SimpleModelProviderEntity(self.provider),
|
||||||
|
status=ModelStatus.ACTIVE
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# if llm name not in restricted llm list, remove it
|
# if llm name not in restricted llm list, remove it
|
||||||
|
restrict_model_names = [rm.model for rm in restrict_models]
|
||||||
for m in provider_models:
|
for m in provider_models:
|
||||||
if m.model_type == ModelType.LLM and m.model not in restrict_llms:
|
if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
|
||||||
m.status = ModelStatus.NO_PERMISSION
|
m.status = ModelStatus.NO_PERMISSION
|
||||||
elif not quota_configuration.is_valid:
|
elif not quota_configuration.is_valid:
|
||||||
m.status = ModelStatus.QUOTA_EXCEEDED
|
m.status = ModelStatus.QUOTA_EXCEEDED
|
||||||
|
|
||||||
return provider_models
|
return provider_models
|
||||||
|
|
||||||
def _get_custom_provider_models(self,
|
def _get_custom_provider_models(self,
|
||||||
|
@ -21,6 +21,12 @@ class SystemConfigurationStatus(Enum):
|
|||||||
UNSUPPORTED = 'unsupported'
|
UNSUPPORTED = 'unsupported'
|
||||||
|
|
||||||
|
|
||||||
|
class RestrictModel(BaseModel):
    """
    Model class for one entry of a quota's ``restrict_models`` list.
    """
    # name of the restricted model
    model: str
    # optional base model name; defaults to None when not given
    base_model_name: Optional[str] = None
    # type of the restricted model
    model_type: ModelType
|
||||||
|
|
||||||
|
|
||||||
class QuotaConfiguration(BaseModel):
|
class QuotaConfiguration(BaseModel):
|
||||||
"""
|
"""
|
||||||
Model class for provider quota configuration.
|
Model class for provider quota configuration.
|
||||||
@ -30,7 +36,7 @@ class QuotaConfiguration(BaseModel):
|
|||||||
quota_limit: int
|
quota_limit: int
|
||||||
quota_used: int
|
quota_used: int
|
||||||
is_valid: bool
|
is_valid: bool
|
||||||
restrict_llms: list[str] = []
|
restrict_models: list[RestrictModel] = []
|
||||||
|
|
||||||
|
|
||||||
class SystemConfiguration(BaseModel):
|
class SystemConfiguration(BaseModel):
|
||||||
|
@ -4,13 +4,14 @@ from typing import Optional
|
|||||||
from flask import Flask
|
from flask import Flask
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.entities.provider_entities import QuotaUnit
|
from core.entities.provider_entities import QuotaUnit, RestrictModel
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from models.provider import ProviderQuotaType
|
from models.provider import ProviderQuotaType
|
||||||
|
|
||||||
|
|
||||||
class HostingQuota(BaseModel):
|
class HostingQuota(BaseModel):
|
||||||
quota_type: ProviderQuotaType
|
quota_type: ProviderQuotaType
|
||||||
restrict_llms: list[str] = []
|
restrict_models: list[RestrictModel] = []
|
||||||
|
|
||||||
|
|
||||||
class TrialHostingQuota(HostingQuota):
|
class TrialHostingQuota(HostingQuota):
|
||||||
@ -47,10 +48,9 @@ class HostingConfiguration:
|
|||||||
provider_map: dict[str, HostingProvider] = {}
|
provider_map: dict[str, HostingProvider] = {}
|
||||||
moderation_config: HostedModerationConfig = None
|
moderation_config: HostedModerationConfig = None
|
||||||
|
|
||||||
def init_app(self, app: Flask):
|
def init_app(self, app: Flask) -> None:
|
||||||
if app.config.get('EDITION') != 'CLOUD':
|
|
||||||
return
|
|
||||||
|
|
||||||
|
self.provider_map["azure_openai"] = self.init_azure_openai()
|
||||||
self.provider_map["openai"] = self.init_openai()
|
self.provider_map["openai"] = self.init_openai()
|
||||||
self.provider_map["anthropic"] = self.init_anthropic()
|
self.provider_map["anthropic"] = self.init_anthropic()
|
||||||
self.provider_map["minimax"] = self.init_minimax()
|
self.provider_map["minimax"] = self.init_minimax()
|
||||||
@ -59,6 +59,47 @@ class HostingConfiguration:
|
|||||||
|
|
||||||
self.moderation_config = self.init_moderation_config()
|
self.moderation_config = self.init_moderation_config()
|
||||||
|
|
||||||
|
def init_azure_openai(self) -> HostingProvider:
    """
    Build the hosted Azure OpenAI provider from environment variables.

    The provider is enabled only when ``HOSTED_AZURE_OPENAI_ENABLED`` is set and
    equals ``true`` (case-insensitive); otherwise a disabled provider is returned.
    When enabled, credentials are read from ``HOSTED_AZURE_OPENAI_API_KEY`` and
    ``HOSTED_AZURE_OPENAI_API_BASE``, and a trial quota is added with the limit
    from ``HOSTED_AZURE_OPENAI_QUOTA_LIMIT`` (default ``1000``).

    :return: hosting provider
    """
    quota_unit = QuotaUnit.TIMES

    enabled_flag = os.environ.get("HOSTED_AZURE_OPENAI_ENABLED")
    if not (enabled_flag and enabled_flag.lower() == 'true'):
        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    credentials = {
        "openai_api_key": os.environ.get("HOSTED_AZURE_OPENAI_API_KEY"),
        "openai_api_base": os.environ.get("HOSTED_AZURE_OPENAI_API_BASE"),
        "base_model_name": "gpt-35-turbo"
    }

    quotas = []
    hosted_quota_limit = int(os.environ.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
    # NOTE: this condition is false only when the limit is exactly -1.
    if hosted_quota_limit != -1 or hosted_quota_limit > 0:
        # Each restricted model uses its own name as the base model name.
        llm_names = (
            "gpt-4",
            "gpt-4-32k",
            "gpt-4-1106-preview",
            "gpt-4-vision-preview",
            "gpt-35-turbo",
            "gpt-35-turbo-1106",
            "gpt-35-turbo-instruct",
            "gpt-35-turbo-16k",
            "text-davinci-003",
        )
        restrict_models = [
            RestrictModel(model=name, base_model_name=name, model_type=ModelType.LLM)
            for name in llm_names
        ]
        restrict_models.append(
            RestrictModel(
                model="text-embedding-ada-002",
                base_model_name="text-embedding-ada-002",
                model_type=ModelType.TEXT_EMBEDDING
            )
        )

        quotas.append(
            TrialHostingQuota(
                quota_limit=hosted_quota_limit,
                restrict_models=restrict_models
            )
        )

    return HostingProvider(
        enabled=True,
        credentials=credentials,
        quota_unit=quota_unit,
        quotas=quotas
    )
|
||||||
|
|
||||||
def init_openai(self) -> HostingProvider:
|
def init_openai(self) -> HostingProvider:
|
||||||
quota_unit = QuotaUnit.TIMES
|
quota_unit = QuotaUnit.TIMES
|
||||||
if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
|
if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
|
||||||
@ -77,12 +118,12 @@ class HostingConfiguration:
|
|||||||
if hosted_quota_limit != -1 or hosted_quota_limit > 0:
|
if hosted_quota_limit != -1 or hosted_quota_limit > 0:
|
||||||
trial_quota = TrialHostingQuota(
|
trial_quota = TrialHostingQuota(
|
||||||
quota_limit=hosted_quota_limit,
|
quota_limit=hosted_quota_limit,
|
||||||
restrict_llms=[
|
restrict_models=[
|
||||||
"gpt-3.5-turbo",
|
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
|
||||||
"gpt-3.5-turbo-1106",
|
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
|
||||||
"gpt-3.5-turbo-instruct",
|
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
|
||||||
"gpt-3.5-turbo-16k",
|
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
|
||||||
"text-davinci-003"
|
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
quotas.append(trial_quota)
|
quotas.append(trial_quota)
|
||||||
|
@ -144,7 +144,7 @@ class ModelInstance:
|
|||||||
user=user
|
user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
|
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None, **params) \
|
||||||
-> str:
|
-> str:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
@ -161,7 +161,8 @@ class ModelInstance:
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
credentials=self.credentials,
|
credentials=self.credentials,
|
||||||
file=file,
|
file=file,
|
||||||
user=user
|
user=user,
|
||||||
|
**params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ class ModelType(Enum):
|
|||||||
return cls.TEXT_EMBEDDING
|
return cls.TEXT_EMBEDDING
|
||||||
elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
|
elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
|
||||||
return cls.RERANK
|
return cls.RERANK
|
||||||
elif origin_model_type == cls.SPEECH2TEXT.value:
|
elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
|
||||||
return cls.SPEECH2TEXT
|
return cls.SPEECH2TEXT
|
||||||
elif origin_model_type == cls.MODERATION.value:
|
elif origin_model_type == cls.MODERATION.value:
|
||||||
return cls.MODERATION
|
return cls.MODERATION
|
||||||
|
@ -2,7 +2,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode
|
from core.model_runtime.entities.llm_entities import LLMMode
|
||||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \
|
from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \
|
||||||
DefaultParameterName, PriceConfig
|
DefaultParameterName, PriceConfig, ModelPropertyKey
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject
|
from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject
|
||||||
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||||
|
|
||||||
@ -502,8 +502,8 @@ EMBEDDING_BASE_MODELS = [
|
|||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model_properties={
|
model_properties={
|
||||||
'context_size': 8097,
|
ModelPropertyKey.CONTEXT_SIZE: 8097,
|
||||||
'max_chunks': 32,
|
ModelPropertyKey.MAX_CHUNKS: 32,
|
||||||
},
|
},
|
||||||
pricing=PriceConfig(
|
pricing=PriceConfig(
|
||||||
input=0.0001,
|
input=0.0001,
|
||||||
|
@ -30,7 +30,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
stream: bool = True, user: Optional[str] = None) \
|
stream: bool = True, user: Optional[str] = None) \
|
||||||
-> Union[LLMResult, Generator]:
|
-> Union[LLMResult, Generator]:
|
||||||
|
|
||||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
|
||||||
|
|
||||||
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||||
# chat model
|
# chat model
|
||||||
@ -59,7 +59,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||||
|
|
||||||
model_mode = self._get_ai_model_entity(credentials['base_model_name'], model).entity.model_properties.get(
|
model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
|
||||||
ModelPropertyKey.MODE)
|
ModelPropertyKey.MODE)
|
||||||
|
|
||||||
if model_mode == LLMMode.CHAT.value:
|
if model_mode == LLMMode.CHAT.value:
|
||||||
@ -79,7 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
if 'base_model_name' not in credentials:
|
if 'base_model_name' not in credentials:
|
||||||
raise CredentialsValidateFailedError('Base Model Name is required')
|
raise CredentialsValidateFailedError('Base Model Name is required')
|
||||||
|
|
||||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
|
||||||
|
|
||||||
if not ai_model_entity:
|
if not ai_model_entity:
|
||||||
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
||||||
@ -109,8 +109,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
raise CredentialsValidateFailedError(str(ex))
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
|
||||||
return ai_model_entity.entity
|
return ai_model_entity.entity if ai_model_entity else None
|
||||||
|
|
||||||
def _generate(self, model: str, credentials: dict,
|
def _generate(self, model: str, credentials: dict,
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
|
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
|
||||||
|
@ -12,7 +12,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
|
|||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
|
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
|
||||||
|
ConfigurateMethod
|
||||||
from core.model_runtime.model_providers import model_provider_factory
|
from core.model_runtime.model_providers import model_provider_factory
|
||||||
from extensions import ext_hosting_provider
|
from extensions import ext_hosting_provider
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -607,7 +608,7 @@ class ProviderManager:
|
|||||||
quota_used=provider_record.quota_used,
|
quota_used=provider_record.quota_used,
|
||||||
quota_limit=provider_record.quota_limit,
|
quota_limit=provider_record.quota_limit,
|
||||||
is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1,
|
is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1,
|
||||||
restrict_llms=provider_quota.restrict_llms
|
restrict_models=provider_quota.restrict_models
|
||||||
)
|
)
|
||||||
|
|
||||||
quota_configurations.append(quota_configuration)
|
quota_configurations.append(quota_configuration)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user