from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from models.provider import ProviderType
from services.provider_service import ProviderService

# Model types accepted by both DefaultModelApi handlers; kept in one place so
# the GET and POST validation cannot drift apart.
_DEFAULT_MODEL_TYPE_CHOICES = ['text-generation', 'embeddings', 'speech2text']


class DefaultModelApi(Resource):
    """Read or update the current tenant's default model for a model type."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """Return the tenant's default model for the ``model_type`` query arg.

        Returns ``None`` when no default model is configured. Otherwise returns
        a dict with ``model_name``, ``model_type`` and a ``model_provider``
        dict. If the preferred provider cannot be resolved, ``model_provider``
        only carries ``provider_name``; otherwise it also carries
        ``provider_type``, plus quota fields for system-type providers.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('model_type', type=str, required=True, nullable=False,
                            choices=_DEFAULT_MODEL_TYPE_CHOICES, location='args')
        args = parser.parse_args()

        tenant_id = current_user.current_tenant_id

        provider_service = ProviderService()
        default_model = provider_service.get_default_model_of_model_type(
            tenant_id=tenant_id,
            model_type=args['model_type']
        )

        if not default_model:
            return None

        model_provider = ModelProviderFactory.get_preferred_model_provider(
            tenant_id,
            default_model.provider_name
        )

        # No resolvable provider: report only what the stored default knows.
        if not model_provider:
            return {
                'model_name': default_model.model_name,
                'model_type': default_model.model_type,
                'model_provider': {
                    'provider_name': default_model.provider_name
                }
            }

        provider = model_provider.provider
        rst = {
            'model_name': default_model.model_name,
            'model_type': default_model.model_type,
            'model_provider': {
                'provider_name': provider.provider_name,
                'provider_type': provider.provider_type
            }
        }

        # Quota details only apply to system-provided (hosted) providers, so
        # the provider rules are only looked up in that case.
        if provider.provider_type == ProviderType.SYSTEM.value:
            model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
            rst['model_provider']['quota_type'] = provider.quota_type
            rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
            rst['model_provider']['quota_limit'] = provider.quota_limit
            rst['model_provider']['quota_used'] = provider.quota_used

        return rst

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Set the tenant's default model for a model type from the JSON body.

        Expects ``model_name``, ``model_type`` and ``provider_name`` in the
        request JSON; returns ``{'result': 'success'}`` after the service call.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
        parser.add_argument('model_type', type=str, required=True, nullable=False,
                            choices=_DEFAULT_MODEL_TYPE_CHOICES, location='json')
        parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
        args = parser.parse_args()

        provider_service = ProviderService()
        provider_service.update_default_model_of_model_type(
            tenant_id=current_user.current_tenant_id,
            model_type=args['model_type'],
            provider_name=args['provider_name'],
            model_name=args['model_name']
        )

        return {'result': 'success'}


class ValidModelApi(Resource):
    """List the models the current tenant can use for a given model type."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, model_type):
        """Return the tenant's valid models for ``model_type`` (from the URL).

        ``ModelType.value_of`` is called only to validate the path segment;
        presumably it raises for unknown types — confirm against ModelType.
        """
        ModelType.value_of(model_type)

        provider_service = ProviderService()
        valid_models = provider_service.get_valid_model_list(
            tenant_id=current_user.current_tenant_id,
            model_type=model_type
        )

        return valid_models


api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
# The URL variable is required: ValidModelApi.get takes ``model_type`` as a
# positional argument, which Flask supplies from this route segment.
api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')