Garfield Dai 42a5b3ec17
feat: advanced prompt backend (#1301)
Co-authored-by: takatost <takatost@gmail.com>
2023-10-12 10:13:10 -05:00

303 lines
8.6 KiB
Python

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Type, Optional
from flask import current_app
from pydantic import BaseModel
from core.model_providers.error import QuotaExceededError, LLMBadRequestError
from extensions.ext_database import db
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from core.model_providers.models.entity.provider import ProviderQuotaUnit
from core.model_providers.rules import provider_rules
from models.provider import Provider, ProviderType, ProviderModel
class BaseModelProvider(BaseModel, ABC):
    """
    Abstract base class for model providers bound to a ``Provider`` record.

    Concrete providers implement the abstract members (provider name, model
    catalogue, credential validation/encryption, parameter rules). This base
    class supplies the shared logic built on top of them: resolving the
    supported model list, system quota checks and deduction, last-used
    bookkeeping and ``ProviderModel`` lookup.
    """

    # Provider record this instance operates on (tenant, name, type, quota).
    provider: Provider

    class Config:
        """Pydantic configuration: allow field types pydantic does not know."""
        arbitrary_types_allowed = True

    @property
    @abstractmethod
    def provider_name(self):
        """
        Name of the provider; used as the key into ``provider_rules``.
        """
        raise NotImplementedError

    def get_rules(self):
        """
        Return the ``provider_rules`` entry registered under ``provider_name``.
        """
        return provider_rules[self.provider_name]

    def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
        """
        List the models of the given type that can be used with this provider.

        The fixed list from ``_get_fixed_model_list`` is returned when the
        provider rules do not list 'custom' in 'support_provider_types', have
        no 'model_flexibility' entry, or set 'model_flexibility' to 'fixed'.
        Otherwise the valid ``ProviderModel`` rows stored for this tenant and
        provider are listed, ordered by creation time ascending.

        :param model_type: type of the models to list
        :return: for configurable providers, dicts with 'id' and 'name' (both
                 the model name) plus 'mode' for text generation models
        """
        rules = self.get_rules()

        # Evaluated left to right, so 'model_flexibility' is only read once
        # its presence has been established.
        fixed_only = (
            'custom' not in rules['support_provider_types']
            or 'model_flexibility' not in rules
            or rules['model_flexibility'] == 'fixed'
        )
        if fixed_only:
            return self._get_fixed_model_list(model_type)

        # get configurable provider models
        query = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_type == model_type.value,
            ProviderModel.is_valid == True
        )
        records = query.order_by(ProviderModel.created_at.asc()).all()

        with_mode = model_type == ModelType.TEXT_GENERATION
        models = []
        for record in records:
            entry = {
                'id': record.model_name,
                'name': record.model_name
            }
            if with_mode:
                entry['mode'] = self._get_text_generation_model_mode(record.model_name)
            models.append(entry)

        return models

    @abstractmethod
    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """
        Return the provider's predefined model list for the given type.

        :param model_type: type of the models to list
        :return: list of model dicts
        """
        raise NotImplementedError

    @abstractmethod
    def _get_text_generation_model_mode(self, model_name) -> str:
        """
        Return the mode of a text generation model.

        :param model_name: name of the model
        :return: mode string stored under 'mode' in the supported model list
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_class(self, model_type: ModelType) -> Type:
        """
        Return the class implementing the given model type for this provider.

        :param model_type: type of the model
        :return: model class
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validate provider-level credentials, raising when they are invalid.

        :param credentials: credentials to validate
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """
        Encrypt provider-level credentials so they can be saved.

        :param tenant_id: tenant the credentials belong to
        :param credentials: credentials to encrypt
        :return: encrypted credentials
        """
        raise NotImplementedError

    @abstractmethod
    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        Return the provider-level credentials used to call the LLM.

        :param obfuscated: request an obfuscated form of the credentials
        :return: credentials dict
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        Validate model-level credentials, raising when they are invalid.

        :param model_name: name of the model
        :param model_type: type of the model
        :param credentials: credentials to validate
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_model_credentials(
        cls,
        tenant_id: str,
        model_name: str,
        model_type: ModelType,
        credentials: dict
    ) -> dict:
        """
        Encrypt model-level credentials so they can be saved.

        :param tenant_id: tenant the credentials belong to
        :param model_name: name of the model
        :param model_type: type of the model
        :param credentials: credentials to encrypt
        :return: encrypted credentials
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        Return the parameter rules of a model.

        :param model_name: name of the model
        :param model_type: type of the model
        :return: parameter rules
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        Return the model-level credentials used to call the LLM.

        :param model_name: name of the model
        :param model_type: type of the model
        :param obfuscated: request an obfuscated form of the credentials
        :return: credentials dict
        """
        raise NotImplementedError

    @classmethod
    def is_provider_type_system_supported(cls) -> bool:
        """
        Tell whether the app's 'EDITION' config value is 'CLOUD'.

        :return: True when running the CLOUD edition
        """
        return current_app.config['EDITION'] == 'CLOUD'

    def check_quota_over_limit(self):
        """
        Raise when a system provider has no remaining quota.

        Does nothing unless the provider type is SYSTEM and the provider
        rules list 'system' in 'support_provider_types'.

        :raises QuotaExceededError: no valid provider row with this id has
                                    ``quota_limit`` greater than ``quota_used``
        """
        if self.provider.provider_type != ProviderType.SYSTEM.value:
            return

        rules = self.get_rules()
        if 'system' not in rules['support_provider_types']:
            return

        # A row only matches while it is valid and still below its limit.
        has_quota_left = db.and_(
            Provider.id == self.provider.id,
            Provider.is_valid == True,
            Provider.quota_limit > Provider.quota_used
        )
        available = db.session.query(Provider).filter(has_quota_left).first()

        if not available:
            raise QuotaExceededError()

    def deduct_quota(self, used_tokens: int = 0) -> None:
        """
        Add the consumed quota to ``quota_used`` of a system provider.

        Does nothing unless the provider type is SYSTEM, the provider rules
        list 'system' in 'support_provider_types' and ``should_deduct_quota``
        returns a truthy value. The amount is ``used_tokens`` when the rules'
        'system_config'/'quota_unit' is TOKENS, otherwise 1 (TIMES is assumed
        when the unit is not configured).

        :param used_tokens: number of tokens consumed by the call
        :return:
        """
        if self.provider.provider_type != ProviderType.SYSTEM.value:
            return

        rules = self.get_rules()
        if 'system' not in rules['support_provider_types']:
            return

        if not self.should_deduct_quota():
            return

        # Fall back to per-call counting when no unit is configured.
        quota_unit = ProviderQuotaUnit.TIMES.value
        if 'system_config' in rules and 'quota_unit' in rules['system_config']:
            quota_unit = rules['system_config']['quota_unit']

        used_quota = used_tokens if quota_unit == ProviderQuotaUnit.TOKENS.value else 1

        # Only rows that are still below their limit are incremented.
        below_limit = db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name,
            Provider.provider_type == self.provider.provider_type,
            Provider.quota_type == self.provider.quota_type,
            Provider.quota_limit > Provider.quota_used
        )
        below_limit.update({'quota_used': Provider.quota_used + used_quota})
        db.session.commit()

    def should_deduct_quota(self):
        """
        Tell whether ``deduct_quota`` should deduct anything.

        The base implementation returns False, which makes ``deduct_quota``
        a no-op unless a subclass overrides this method.
        """
        return False

    def update_last_used(self) -> None:
        """
        Set ``last_used`` to the current UTC time on every provider row of
        this tenant with this provider name, then commit.

        :return:
        """
        same_provider = db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name
        )
        same_provider.update({'last_used': datetime.utcnow()})
        db.session.commit()

    def get_payment_info(self) -> Optional[dict]:
        """
        Return product info when the provider is payable.

        :return: None in the base implementation
        """
        return None

    def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
        """
        Fetch the valid ``ProviderModel`` row of this tenant and provider for
        the given model name and type.

        :param model_name: name of the model
        :param model_type: type of the model
        :return: the first matching row
        :raises LLMBadRequestError: no matching row exists
        """
        lookup = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_name == model_name,
            ProviderModel.model_type == model_type.value,
            ProviderModel.is_valid == True
        )
        found = lookup.first()

        if not found:
            raise LLMBadRequestError(f"The model {model_name} does not exist. "
                                     f"Please check the configuration.")

        return found
class CredentialsValidateFailedError(Exception):
    """Raised when provider or model credentials fail validation."""