diff --git a/api/.env.example b/api/.env.example index 8a451dec17..012c8a5c65 100644 --- a/api/.env.example +++ b/api/.env.example @@ -124,5 +124,11 @@ HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000 HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20 HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100 +# Stripe configuration STRIPE_API_KEY= -STRIPE_WEBHOOK_SECRET= \ No newline at end of file +STRIPE_WEBHOOK_SECRET= + +# Billing configuration +BILLING_API_URL=http://127.0.0.1:8000/v1 +BILLING_API_SECRET_KEY= +STRIPE_WEBHOOK_BILLING_SECRET= \ No newline at end of file diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index ac881dc126..99d677970c 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -28,3 +28,5 @@ from .universal_chat import chat, conversation, message, parameter, audio # Import webhook controllers from .webhook import stripe + +from .billing import billing diff --git a/api/controllers/console/billing/__init__.py b/api/controllers/console/billing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py new file mode 100644 index 0000000000..6bad91f411 --- /dev/null +++ b/api/controllers/console/billing/billing.py @@ -0,0 +1,85 @@ +import stripe +import os + +from flask_restful import Resource, reqparse +from flask_login import current_user +from flask import current_app, request + +from controllers.console import api +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import only_edition_cloud +from libs.login import login_required +from services.billing_service import BillingService + + +class BillingInfo(Resource): + + @setup_required + @login_required + @account_initialization_required + def get(self): + + edition = current_app.config['EDITION'] + if edition != 'CLOUD': + return {"enabled": False} + + return BillingService.get_info(current_user.current_tenant_id) + + +class Subscription(Resource): + + @setup_required + @login_required + @account_initialization_required + @only_edition_cloud + def get(self): + + parser = reqparse.RequestParser() + parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team']) + parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) + args = parser.parse_args() + + return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id) + + +class Invoices(Resource): + + @setup_required + @login_required + @account_initialization_required + @only_edition_cloud + def get(self): + + return BillingService.get_invoices(current_user.email) + + +class StripeBillingWebhook(Resource): + + @setup_required + @only_edition_cloud + def post(self): + payload = request.data + sig_header = request.headers.get('STRIPE_SIGNATURE') + webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET') + + try: + event = stripe.Webhook.construct_event( + payload, sig_header, webhook_secret + ) + except ValueError as e: + # Invalid payload + return 'Invalid payload', 400 + except stripe.error.SignatureVerificationError as e: + # Invalid signature + return 'Invalid signature', 400 + + BillingService.process_event(event) + + return 'success', 200 + + +api.add_resource(BillingInfo, '/billing/info') +api.add_resource(Subscription, '/billing/subscription') +api.add_resource(Invoices, '/billing/invoices') +api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe') diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 9a417b3660..b4323f18ec 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -493,3 +493,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/') api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/') + diff --git a/api/services/billing_service.py b/api/services/billing_service.py new file mode 100644 index 0000000000..10d991ad18 --- /dev/null +++ b/api/services/billing_service.py @@ -0,0 +1,55 @@ +import os +import requests + +from services.dataset_service import DatasetService + + +class BillingService: + base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') + secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') + + @classmethod + def get_info(cls, tenant_id: str): + params = {'tenant_id': tenant_id} + + billing_info = cls._send_request('GET', '/info', params=params) + + vector_size = DatasetService.get_tenant_datasets_usage(tenant_id) / 1024 + billing_info['vector_space']['size'] = int(vector_size) + + return billing_info + + @classmethod + def get_subscription(cls, plan: str, interval: str, prefilled_email: str = '', user_name: str = '', tenant_id: str = ''): + params = { + 'plan': plan, + 'interval': interval, + 'prefilled_email': prefilled_email, + 'user_name': user_name, + 'tenant_id': tenant_id + } + return cls._send_request('GET', '/subscription', params=params) + + @classmethod + def get_invoices(cls, prefilled_email: str = ''): + params = {'prefilled_email': prefilled_email} + return cls._send_request('GET', '/invoices', params=params) + + @classmethod + def _send_request(cls, method, endpoint, json=None, params=None): + headers = { + "Content-Type": "application/json", + "Billing-Api-Secret-Key": cls.secret_key + } + + url = f"{cls.base_url}{endpoint}" + response = requests.request(method, url, json=json, params=params, headers=headers) + + return response.json() + + @classmethod + def process_event(cls, event: dict): + json = { + "content": event, + } + return cls._send_request('POST', '/webhook/stripe', json=json) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 4093a8f72f..47ecca02c7 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -227,6 +227,36 @@ class DatasetService: return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \ .order_by(db.desc(AppDatasetJoin.created_at)).all() + @staticmethod + def get_tenant_datasets_usage(tenant_id): + # get the high_quality datasets + dataset_ids = db.session.query(Dataset.id).filter(Dataset.indexing_technique == 'high_quality', + Dataset.tenant_id == tenant_id).all() + if not dataset_ids: + return 0 + dataset_ids = [result[0] for result in dataset_ids] + document_ids = db.session.query(Document.id).filter(Document.dataset_id.in_(dataset_ids), + Document.tenant_id == tenant_id, + Document.completed_at.isnot(None), + Document.enabled == True, + Document.archived == False + ).all() + if not document_ids: + return 0 + document_ids = [result[0] for result in document_ids] + document_segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids), + DocumentSegment.tenant_id == tenant_id, + DocumentSegment.completed_at.isnot(None), + DocumentSegment.enabled == True, + ).all() + if not document_segments: + return 0 + + total_words_size = sum(document_segment.word_count * 3 for document_segment in document_segments) + total_vector_size = 1536 * 4 * len(document_segments) + + return total_words_size + total_vector_size + class DocumentService: DEFAULT_RULES = { @@ -488,7 +518,8 @@ class DocumentService: 'score_threshold_enabled': False } - dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model + dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get( + 'retrieval_model') else default_retrieval_model documents = [] batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))