Mirror of https://github.com/langgenius/dify.git
(via https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git),
synced 2025-06-04 11:14:10 +08:00.
Python · 303 lines · 8.6 KiB
from abc import ABC, abstractmethod
|
|
from datetime import datetime
|
|
from typing import Type, Optional
|
|
|
|
from flask import current_app
|
|
from pydantic import BaseModel
|
|
|
|
from core.model_providers.error import QuotaExceededError, LLMBadRequestError
|
|
from extensions.ext_database import db
|
|
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
|
|
from core.model_providers.models.entity.provider import ProviderQuotaUnit
|
|
from core.model_providers.rules import provider_rules
|
|
from models.provider import Provider, ProviderType, ProviderModel
|
|
|
|
|
|
class BaseModelProvider(BaseModel, ABC):
    """
    Abstract base for model providers.

    Wraps a persisted ``Provider`` row and declares the hooks concrete
    providers implement (model listing, credential validation/encryption,
    parameter rules), plus shared quota and bookkeeping helpers that read
    and write the ``Provider`` / ``ProviderModel`` tables via ``db.session``.
    """

    provider: Provider

    class Config:
        """Configuration for this pydantic object."""

        # ``Provider`` is a database model rather than a pydantic type, so
        # pydantic must be told to accept it as an opaque field type.
        arbitrary_types_allowed = True

    @property
    @abstractmethod
    def provider_name(self):
        """
        Returns the name of a provider.

        Used as the lookup key into ``provider_rules`` by :meth:`get_rules`.
        """
        raise NotImplementedError

    def get_rules(self):
        """
        Returns the rules of a provider.

        :return: the ``provider_rules`` entry keyed by :attr:`provider_name`.
            A missing entry raises whatever that mapping raises
            (``KeyError`` if it is a plain dict — presumably it is; confirm).
        """
        return provider_rules[self.provider_name]

    def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
        """
        get supported model object list for use.

        Falls back to :meth:`_get_fixed_model_list` unless the provider's rules
        list ``'custom'`` in ``support_provider_types`` AND declare a
        ``model_flexibility`` other than ``'fixed'``. Otherwise returns this
        tenant's valid configured ``ProviderModel`` rows of ``model_type``,
        ordered by creation time (oldest first).

        :param model_type: the type of model to list.
        :return: list of dicts with ``id`` and ``name`` (both the model name),
            plus ``mode`` for text-generation models.
        """
        rules = self.get_rules()
        # A missing 'model_flexibility' key is treated exactly like 'fixed'.
        if ('custom' not in rules['support_provider_types']
                or rules.get('model_flexibility', 'fixed') == 'fixed'):
            return self._get_fixed_model_list(model_type)

        # get configurable provider models
        # NOTE: `== True` (not `is True`) is required so SQLAlchemy builds a SQL
        # expression; linters flag it as E712.
        provider_models = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_type == model_type.value,
            ProviderModel.is_valid == True  # noqa: E712
        ).order_by(ProviderModel.created_at.asc()).all()

        is_text_generation = model_type == ModelType.TEXT_GENERATION
        provider_model_list = []
        for provider_model in provider_models:
            model_name = provider_model.model_name
            provider_model_dict = {'id': model_name, 'name': model_name}
            if is_text_generation:
                provider_model_dict['mode'] = self._get_text_generation_model_mode(model_name)
            provider_model_list.append(provider_model_dict)

        return provider_model_list

    @abstractmethod
    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """
        get the provider's built-in (non-configurable) model list.

        :param model_type: the type of model to list.
        :return: list of model dicts, in the same shape as
            :meth:`get_supported_model_list` returns.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_text_generation_model_mode(self, model_name: str) -> str:
        """
        get text generation model mode.

        :param model_name: name of a configured text-generation model.
        :return: the mode string stored under ``'mode'`` in the model dict.
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_class(self, model_type: ModelType) -> Type:
        """
        get specific model class.

        :param model_type: the type of model.
        :return: the class implementing that model type for this provider.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        check provider credentials valid; raise if they are not.

        :param credentials: provider-level credentials to validate.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """
        encrypt provider credentials for save.

        :param tenant_id: tenant owning the credentials.
        :param credentials: plaintext provider credentials.
        :return: credentials in the form to be persisted.
        """
        raise NotImplementedError

    @abstractmethod
    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        get provider credentials for llm use.

        :param obfuscated: whether to return an obfuscated (display-safe) copy.
        :return: the provider credentials.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid; raise if they are not.

        :param model_name: model the credentials belong to.
        :param model_type: type of that model.
        :param credentials: model-level credentials to validate.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id: tenant owning the credentials.
        :param model_name: model the credentials belong to.
        :param model_type: type of that model.
        :param credentials: plaintext model credentials.
        :return: credentials in the form to be persisted.
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name: model to describe.
        :param model_type: type of that model.
        :return: the parameter rules for the model.
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get model credentials for llm use.

        :param model_name: model whose credentials to return.
        :param model_type: type of that model.
        :param obfuscated: whether to return an obfuscated (display-safe) copy.
        :return: the model credentials.
        """
        raise NotImplementedError

    @classmethod
    def is_provider_type_system_supported(cls) -> bool:
        """
        Whether the system provider type is available in this deployment.

        :return: True only when the Flask app's ``EDITION`` config is ``'CLOUD'``.
        """
        return current_app.config['EDITION'] == 'CLOUD'

    def _uses_system_quota(self) -> bool:
        """Whether this is a system-type provider whose rules support the 'system' type."""
        if self.provider.provider_type != ProviderType.SYSTEM.value:
            return False
        return 'system' in self.get_rules()['support_provider_types']

    def check_quota_over_limit(self) -> None:
        """
        check provider quota over limit.

        Only applies to system-type providers whose rules support the
        ``system`` provider type; for all others this returns immediately.

        :raises QuotaExceededError: if this provider's row is not valid or its
            ``quota_limit`` is not greater than ``quota_used``.
        """
        if not self._uses_system_quota():
            return

        # NOTE(review): if quota_limit can be NULL, the SQL comparison is not
        # true and this raises QuotaExceededError — confirm that is intended.
        provider = db.session.query(Provider).filter(
            db.and_(
                Provider.id == self.provider.id,
                Provider.is_valid == True,  # noqa: E712 -- SQLAlchemy expression
                Provider.quota_limit > Provider.quota_used
            )
        ).first()

        if not provider:
            raise QuotaExceededError()

    def deduct_quota(self, used_tokens: int = 0) -> None:
        """
        deduct available quota for a system-type provider.

        No-op unless the provider is a system-type provider whose rules support
        the ``system`` type AND :meth:`should_deduct_quota` returns True (it
        returns False in this base class). The amount deducted is
        ``used_tokens`` when the configured quota unit is tokens, otherwise 1.
        The unit defaults to "times" when ``system_config`` or its
        ``quota_unit`` is not configured.

        :param used_tokens: tokens consumed; only used for token-based quotas.
        :return: None. Commits the session after the update.
        """
        if not self._uses_system_quota():
            return

        if not self.should_deduct_quota():
            return

        system_config = self.get_rules().get('system_config') or {}
        quota_unit = system_config.get('quota_unit', ProviderQuotaUnit.TIMES.value)
        used_quota = used_tokens if quota_unit == ProviderQuotaUnit.TOKENS.value else 1

        # Rows are matched by tenant, provider name, provider type and quota
        # type; the `quota_limit > quota_used` guard stops rows that are
        # already exhausted from being incremented further.
        db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name,
            Provider.provider_type == self.provider.provider_type,
            Provider.quota_type == self.provider.quota_type,
            Provider.quota_limit > Provider.quota_used
        ).update({'quota_used': Provider.quota_used + used_quota})
        db.session.commit()

    def should_deduct_quota(self) -> bool:
        """
        Hook deciding whether :meth:`deduct_quota` actually deducts.

        :return: False in this base class; subclasses override to enable.
        """
        return False

    def update_last_used(self) -> None:
        """
        update last used time.

        Sets ``last_used`` to the current naive-UTC time on every ``Provider``
        row for this tenant and provider name (no provider-type filter), then
        commits the session.

        :return: None
        """
        db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name
        ).update({'last_used': datetime.utcnow()})
        db.session.commit()

    def get_payment_info(self) -> Optional[dict]:
        """
        get product/payment info if the provider is payable.

        :return: None in this base class; subclasses may override.
        """
        return None

    def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
        """
        get a valid configured provider model for this tenant.

        :param model_name: name of the configured model.
        :param model_type: type of that model.
        :return: the matching, valid ``ProviderModel`` row.
        :raises LLMBadRequestError: if no valid matching row exists.
        """
        provider_model = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_name == model_name,
            ProviderModel.model_type == model_type.value,
            ProviderModel.is_valid == True  # noqa: E712 -- SQLAlchemy expression
        ).first()

        if not provider_model:
            raise LLMBadRequestError(f"The model {model_name} does not exist. "
                                     "Please check the configuration.")

        return provider_model
|
|
|
|
|
|
class CredentialsValidateFailedError(Exception):
    """
    Error signalling that credential validation failed.

    Presumably raised by the ``*_credentials_valid_or_raise`` hooks of concrete
    providers — confirm against the implementations.
    """
|