diff --git a/api/config.py b/api/config.py index 6cabf814bf..2f0ec1cdb7 100644 --- a/api/config.py +++ b/api/config.py @@ -55,6 +55,8 @@ DEFAULTS = { 'OUTPUT_MODERATION_BUFFER_SIZE': 300, 'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64', 'INVITE_EXPIRY_HOURS': 72, + 'BILLING_ENABLED': 'False', + 'CAN_REPLACE_LOGO': 'False', 'ETL_TYPE': 'dify', } @@ -279,6 +281,8 @@ class Config: self.ETL_TYPE = get_env('ETL_TYPE') self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL') + self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED') + self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO') class CloudEditionConfig(Config): diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 21dcbd62be..6fa896c3e8 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -6,7 +6,7 @@ bp = Blueprint('console', __name__, url_prefix='/console/api') api = ExternalApi(bp) # Import other controllers -from . import extension, setup, version, apikey, admin +from . import extension, setup, version, apikey, admin, feature # Import app controllers from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index f3986bed9d..83741ce2ed 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,6 +1,5 @@ from flask_restful import Resource, reqparse from flask_login import current_user -from flask import current_app from controllers.console import api from controllers.console.setup import setup_required @@ -10,20 +9,6 @@ from libs.login import login_required from services.billing_service import BillingService -class BillingInfo(Resource): - - @setup_required - @login_required - @account_initialization_required - def get(self): - - edition = current_app.config['EDITION'] - if edition != 'CLOUD': - return {"enabled": False} - - return BillingService.get_info(current_user.current_tenant_id) - - class Subscription(Resource): @setup_required @@ -56,6 +41,5 @@ class Invoices(Resource): return BillingService.get_invoices(current_user.email) -api.add_resource(BillingInfo, '/billing/info') api.add_resource(Subscription, '/billing/subscription') api.add_resource(Invoices, '/billing/invoices') diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py new file mode 100644 index 0000000000..0d1d61ad00 --- /dev/null +++ b/api/controllers/console/feature.py @@ -0,0 +1,14 @@ +from flask_restful import Resource +from flask_login import current_user + +from . import api +from services.feature_service import FeatureService + + +class FeatureApi(Resource): + + def get(self): + return FeatureService.get_features(current_user.current_tenant_id).dict() + + +api.add_resource(FeatureApi, '/features') diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 19a5de69ed..c19ef14708 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -5,7 +5,7 @@ from flask import current_app, abort from flask_login import current_user from controllers.console.workspace.error import AccountNotInitializedError -from services.billing_service import BillingService +from services.feature_service import FeatureService def account_initialization_required(view): @@ -49,23 +49,23 @@ def cloud_edition_billing_resource_check(resource: str, def interceptor(view): @wraps(view) def decorated(*args, **kwargs): - if current_app.config['EDITION'] == 'CLOUD': - tenant_id = current_user.current_tenant_id - billing_info = BillingService.get_info(tenant_id) - members = billing_info['members'] - apps = billing_info['apps'] - vector_space = billing_info['vector_space'] - annotation_quota_limit = billing_info['annotation_quota_limit'] + features = FeatureService.get_features(current_user.current_tenant_id) - if resource == 'members' and 0 < members['limit'] <= members['size']: + if features.billing.enabled: + members = features.members + apps = features.apps + vector_space = features.vector_space + annotation_quota_limit = features.annotation_quota_limit + + if resource == 'members' and 0 < members.limit <= members.size: abort(403, error_msg) - elif resource == 'apps' and 0 < apps['limit'] <= apps['size']: + elif resource == 'apps' and 0 < apps.limit <= apps.size: abort(403, error_msg) - elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']: + elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: abort(403, error_msg) - elif resource == 'workspace_custom' and not billing_info['can_replace_logo']: + elif resource == 'workspace_custom' and not features.can_replace_logo: abort(403, error_msg) - elif resource == 'annotation' and 0 < annotation_quota_limit['limit'] < annotation_quota_limit['size']: + elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: abort(403, error_msg) else: return view(*args, **kwargs) diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 0b73297dd5..7320a6e614 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -11,8 +11,7 @@ from libs.login import _get_user from extensions.ext_database import db from models.account import Tenant, TenantAccountJoin, Account from models.model import ApiToken, App -from services.billing_service import BillingService - +from services.feature_service import FeatureService def validate_app_token(view=None): def decorator(view): @@ -46,19 +45,19 @@ def cloud_edition_billing_resource_check(resource: str, error_msg: str = "You have reached the limit of your subscription."): def interceptor(view): def decorated(*args, **kwargs): - if current_app.config['EDITION'] == 'CLOUD': - api_token = validate_and_get_api_token(api_token_type) - billing_info = BillingService.get_info(api_token.tenant_id) + api_token = validate_and_get_api_token(api_token_type) + features = FeatureService.get_features(api_token.tenant_id) - members = billing_info['members'] - apps = billing_info['apps'] - vector_space = billing_info['vector_space'] + if features.billing.enabled: + members = features.members + apps = features.apps + vector_space = features.vector_space - if resource == 'members' and 0 < members['limit'] <= members['size']: + if resource == 'members' and 0 < members.limit <= members.size: raise Unauthorized(error_msg) - elif resource == 'apps' and 0 < apps['limit'] <= apps['size']: + elif resource == 'apps' and 0 < apps.limit <= apps.size: raise Unauthorized(error_msg) - elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']: + elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: raise Unauthorized(error_msg) else: return view(*args, **kwargs) diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index e5c27c18c1..0f63e6087b 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -9,7 +9,7 @@ from controllers.web import api from controllers.web.wraps import WebApiResource from extensions.ext_database import db from models.model import Site -from services.billing_service import BillingService +from services.feature_service import FeatureService class AppSiteApi(WebApiResource): @@ -56,12 +56,7 @@ class AppSiteApi(WebApiResource): if not site: raise Forbidden() - edition = os.environ.get('EDITION') - can_replace_logo = False - - if edition == 'CLOUD': - info = BillingService.get_info(app_model.tenant_id) - can_replace_logo = info['can_replace_logo'] + can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) diff --git a/api/services/feature_service.py b/api/services/feature_service.py new file mode 100644 index 0000000000..bf7b378f38 --- /dev/null +++ b/api/services/feature_service.py @@ -0,0 +1,71 @@ +from pydantic import BaseModel +from flask import current_app + +from services.billing_service import BillingService + + +class SubscriptionModel(BaseModel): + plan: str = 'sandbox' + interval: str = '' + + +class BillingModel(BaseModel): + enabled: bool = False + subscription: SubscriptionModel = SubscriptionModel() + + +class LimitationModel(BaseModel): + size: int = 0 + limit: int = 0 + + +class FeatureModel(BaseModel): + billing: BillingModel = BillingModel() + members: LimitationModel = LimitationModel(size=0, limit=1) + apps: LimitationModel = LimitationModel(size=0, limit=10) + vector_space: LimitationModel = LimitationModel(size=0, limit=5) + annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10) + docs_processing: str = 'standard' + can_replace_logo: bool = False + + +class FeatureService: + + @classmethod + def get_features(cls, tenant_id: str) -> FeatureModel: + features = FeatureModel() + + cls._fulfill_params_from_env(features) + + if current_app.config['BILLING_ENABLED']: + cls._fulfill_params_from_billing_api(features, tenant_id) + + return features + + @classmethod + def _fulfill_params_from_env(cls, features: FeatureModel): + features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO'] + + @classmethod + def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): + billing_info = BillingService.get_info(tenant_id) + + features.billing.enabled = billing_info['enabled'] + features.billing.subscription.plan = billing_info['subscription']['plan'] + features.billing.subscription.interval = billing_info['subscription']['interval'] + + features.members.size = billing_info['members']['size'] + features.members.limit = billing_info['members']['limit'] + + features.apps.size = billing_info['apps']['size'] + features.apps.limit = billing_info['apps']['limit'] + + features.vector_space.size = billing_info['vector_space']['size'] + features.vector_space.limit = billing_info['vector_space']['limit'] + + features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size'] + features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit'] + + features.docs_processing = billing_info['docs_processing'] + features.can_replace_logo = billing_info['can_replace_logo'] + diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 6aa7198a04..73651faab1 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -4,7 +4,7 @@ from extensions.ext_database import db from models.account import Tenant, TenantAccountJoin, TenantAccountJoinRole from models.provider import Provider -from services.billing_service import BillingService +from services.feature_service import FeatureService from services.account_service import TenantService @@ -32,12 +32,10 @@ class WorkspaceService: ).first() tenant_info['role'] = tenant_account_join.role - edition = current_app.config['EDITION'] - if edition == 'CLOUD': - billing_info = BillingService.get_info(tenant_info['id']) + can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo - if billing_info['can_replace_logo'] and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): - tenant_info['custom_config'] = tenant.custom_config_dict + if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): + tenant_info['custom_config'] = tenant.custom_config_dict # Get providers providers = db.session.query(Provider).filter( diff --git a/web/app/components/billing/type.ts b/web/app/components/billing/type.ts index 681ed1ee6f..1f25e909fc 100644 --- a/web/app/components/billing/type.ts +++ b/web/app/components/billing/type.ts @@ -32,9 +32,11 @@ export enum DocumentProcessingPriority { } export type CurrentPlanInfoBackend = { - enabled: boolean - subscription: { - plan: Plan + billing: { + enabled: boolean + subscription: { + plan: Plan + } } members: { size: number @@ -53,6 +55,7 @@ export type CurrentPlanInfoBackend = { limit: number // total. 0 means unlimited } docs_processing: DocumentProcessingPriority + can_replace_logo: boolean } export type SubscriptionItem = { diff --git a/web/app/components/billing/utils/index.ts b/web/app/components/billing/utils/index.ts index 405d8656a1..3e689e2544 100644 --- a/web/app/components/billing/utils/index.ts +++ b/web/app/components/billing/utils/index.ts @@ -10,7 +10,7 @@ const parseLimit = (limit: number) => { export const parseCurrentPlan = (data: CurrentPlanInfoBackend) => { return { - type: data.subscription.plan, + type: data.billing.subscription.plan, usage: { vectorSpace: data.vector_space.size, buildApps: data.apps?.size || 0, diff --git a/web/app/components/custom/custom-page/index.tsx b/web/app/components/custom/custom-page/index.tsx index 44b5d0486b..9128ba0dba 100644 --- a/web/app/components/custom/custom-page/index.tsx +++ b/web/app/components/custom/custom-page/index.tsx @@ -10,12 +10,16 @@ import { contactSalesUrl } from '@/app/components/billing/config' const CustomPage = () => { const { t } = useTranslation() - const { plan } = useProviderContext() + const { plan, enableBilling } = useProviderContext() + + const showBillingTip = enableBilling && plan.type === Plan.sandbox + const showCustomAppHeaderBrand = enableBilling && plan.type === Plan.sandbox + const showContact = enableBilling && (plan.type === Plan.professional || plan.type === Plan.team) return (