mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 22:55:54 +08:00
remove stripe and anthropic. (#1746)
This commit is contained in:
parent
4c639961f5
commit
6b499b9a16
@ -106,8 +106,6 @@ HOSTED_OPENAI_API_BASE=
|
|||||||
HOSTED_OPENAI_API_ORGANIZATION=
|
HOSTED_OPENAI_API_ORGANIZATION=
|
||||||
HOSTED_OPENAI_QUOTA_LIMIT=200
|
HOSTED_OPENAI_QUOTA_LIMIT=200
|
||||||
HOSTED_OPENAI_PAID_ENABLED=false
|
HOSTED_OPENAI_PAID_ENABLED=false
|
||||||
HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
|
|
||||||
HOSTED_OPENAI_PAID_INCREASE_QUOTA=1
|
|
||||||
|
|
||||||
HOSTED_AZURE_OPENAI_ENABLED=false
|
HOSTED_AZURE_OPENAI_ENABLED=false
|
||||||
HOSTED_AZURE_OPENAI_API_KEY=
|
HOSTED_AZURE_OPENAI_API_KEY=
|
||||||
@ -119,16 +117,7 @@ HOSTED_ANTHROPIC_API_BASE=
|
|||||||
HOSTED_ANTHROPIC_API_KEY=
|
HOSTED_ANTHROPIC_API_KEY=
|
||||||
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
|
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
|
||||||
HOSTED_ANTHROPIC_PAID_ENABLED=false
|
HOSTED_ANTHROPIC_PAID_ENABLED=false
|
||||||
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
|
|
||||||
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
|
|
||||||
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
|
|
||||||
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
|
|
||||||
|
|
||||||
# Stripe configuration
|
|
||||||
STRIPE_API_KEY=
|
|
||||||
STRIPE_WEBHOOK_SECRET=
|
|
||||||
|
|
||||||
# Billing configuration
|
# Billing configuration
|
||||||
BILLING_API_URL=http://127.0.0.1:8000/v1
|
BILLING_API_URL=http://127.0.0.1:8000/v1
|
||||||
BILLING_API_SECRET_KEY=
|
BILLING_API_SECRET_KEY=
|
||||||
STRIPE_WEBHOOK_BILLING_SECRET=
|
|
@ -20,7 +20,7 @@ from flask_cors import CORS
|
|||||||
|
|
||||||
from core.model_providers.providers import hosted
|
from core.model_providers.providers import hosted
|
||||||
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||||
ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension
|
ext_database, ext_storage, ext_mail, ext_code_based_extension
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_login import login_manager
|
from extensions.ext_login import login_manager
|
||||||
|
|
||||||
@ -96,7 +96,6 @@ def initialize_extensions(app):
|
|||||||
ext_login.init_app(app)
|
ext_login.init_app(app)
|
||||||
ext_mail.init_app(app)
|
ext_mail.init_app(app)
|
||||||
ext_sentry.init_app(app)
|
ext_sentry.init_app(app)
|
||||||
ext_stripe.init_app(app)
|
|
||||||
|
|
||||||
|
|
||||||
# Flask-Login configuration
|
# Flask-Login configuration
|
||||||
|
@ -1,11 +1,8 @@
|
|||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
import os
|
import os
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
import dotenv
|
import dotenv
|
||||||
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from extensions.ext_redis import redis_client
|
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
@ -44,15 +41,11 @@ DEFAULTS = {
|
|||||||
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
|
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
|
||||||
'HOSTED_OPENAI_ENABLED': 'False',
|
'HOSTED_OPENAI_ENABLED': 'False',
|
||||||
'HOSTED_OPENAI_PAID_ENABLED': 'False',
|
'HOSTED_OPENAI_PAID_ENABLED': 'False',
|
||||||
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
|
|
||||||
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
|
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
|
||||||
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
|
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
|
||||||
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
|
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
|
||||||
'HOSTED_ANTHROPIC_ENABLED': 'False',
|
'HOSTED_ANTHROPIC_ENABLED': 'False',
|
||||||
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
|
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
|
||||||
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
|
|
||||||
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
|
|
||||||
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
|
|
||||||
'HOSTED_MODERATION_ENABLED': 'False',
|
'HOSTED_MODERATION_ENABLED': 'False',
|
||||||
'HOSTED_MODERATION_PROVIDERS': '',
|
'HOSTED_MODERATION_PROVIDERS': '',
|
||||||
'CLEAN_DAY_SETTING': 30,
|
'CLEAN_DAY_SETTING': 30,
|
||||||
@ -268,8 +261,6 @@ class Config:
|
|||||||
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
|
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
|
||||||
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
|
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
|
||||||
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
|
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
|
||||||
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
|
|
||||||
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
|
|
||||||
|
|
||||||
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
|
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
|
||||||
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
|
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
|
||||||
@ -281,10 +272,6 @@ class Config:
|
|||||||
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
|
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
|
||||||
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
|
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
|
||||||
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
|
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
|
||||||
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
|
|
||||||
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
|
|
||||||
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
|
|
||||||
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
|
|
||||||
|
|
||||||
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
|
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
|
||||||
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
|
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
|
||||||
@ -302,6 +289,3 @@ class CloudEditionConfig(Config):
|
|||||||
self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
|
self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
|
||||||
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
|
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
|
||||||
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
|
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
|
||||||
|
|
||||||
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
|
|
||||||
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
|
|
||||||
|
@ -26,7 +26,4 @@ from .explore import installed_app, recommended_app, completion, conversation, m
|
|||||||
# Import universal chat controllers
|
# Import universal chat controllers
|
||||||
from .universal_chat import chat, conversation, message, parameter, audio
|
from .universal_chat import chat, conversation, message, parameter, audio
|
||||||
|
|
||||||
# Import webhook controllers
|
|
||||||
from .webhook import stripe
|
|
||||||
|
|
||||||
from .billing import billing
|
from .billing import billing
|
||||||
|
@ -1,9 +1,6 @@
|
|||||||
import stripe
|
|
||||||
import os
|
|
||||||
|
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask import current_app, request
|
from flask import current_app
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
@ -40,7 +37,10 @@ class Subscription(Resource):
|
|||||||
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
|
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id)
|
return BillingService.get_subscription(args['plan'],
|
||||||
|
args['interval'],
|
||||||
|
current_user.email,
|
||||||
|
current_user.current_tenant_id)
|
||||||
|
|
||||||
|
|
||||||
class Invoices(Resource):
|
class Invoices(Resource):
|
||||||
@ -54,32 +54,6 @@ class Invoices(Resource):
|
|||||||
return BillingService.get_invoices(current_user.email)
|
return BillingService.get_invoices(current_user.email)
|
||||||
|
|
||||||
|
|
||||||
class StripeBillingWebhook(Resource):
|
|
||||||
|
|
||||||
@setup_required
|
|
||||||
@only_edition_cloud
|
|
||||||
def post(self):
|
|
||||||
payload = request.data
|
|
||||||
sig_header = request.headers.get('STRIPE_SIGNATURE')
|
|
||||||
webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET')
|
|
||||||
|
|
||||||
try:
|
|
||||||
event = stripe.Webhook.construct_event(
|
|
||||||
payload, sig_header, webhook_secret
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
# Invalid payload
|
|
||||||
return 'Invalid payload', 400
|
|
||||||
except stripe.error.SignatureVerificationError as e:
|
|
||||||
# Invalid signature
|
|
||||||
return 'Invalid signature', 400
|
|
||||||
|
|
||||||
BillingService.process_event(event)
|
|
||||||
|
|
||||||
return 'success', 200
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(BillingInfo, '/billing/info')
|
api.add_resource(BillingInfo, '/billing/info')
|
||||||
api.add_resource(Subscription, '/billing/subscription')
|
api.add_resource(Subscription, '/billing/subscription')
|
||||||
api.add_resource(Invoices, '/billing/invoices')
|
api.add_resource(Invoices, '/billing/invoices')
|
||||||
api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe')
|
|
||||||
|
@ -1,61 +0,0 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
import stripe
|
|
||||||
from flask import request, current_app
|
|
||||||
from flask_restful import Resource
|
|
||||||
|
|
||||||
from controllers.console import api
|
|
||||||
from controllers.console.setup import setup_required
|
|
||||||
from controllers.console.wraps import only_edition_cloud
|
|
||||||
from services.provider_checkout_service import ProviderCheckoutService
|
|
||||||
|
|
||||||
|
|
||||||
class StripeWebhookApi(Resource):
|
|
||||||
@setup_required
|
|
||||||
@only_edition_cloud
|
|
||||||
def post(self):
|
|
||||||
payload = request.data
|
|
||||||
sig_header = request.headers.get('STRIPE_SIGNATURE')
|
|
||||||
webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')
|
|
||||||
|
|
||||||
try:
|
|
||||||
event = stripe.Webhook.construct_event(
|
|
||||||
payload, sig_header, webhook_secret
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
# Invalid payload
|
|
||||||
return 'Invalid payload', 400
|
|
||||||
except stripe.error.SignatureVerificationError as e:
|
|
||||||
# Invalid signature
|
|
||||||
return 'Invalid signature', 400
|
|
||||||
|
|
||||||
# Handle the checkout.session.completed event
|
|
||||||
if event['type'] == 'checkout.session.completed':
|
|
||||||
logging.debug(event['data']['object']['id'])
|
|
||||||
logging.debug(event['data']['object']['amount_subtotal'])
|
|
||||||
logging.debug(event['data']['object']['currency'])
|
|
||||||
logging.debug(event['data']['object']['payment_intent'])
|
|
||||||
logging.debug(event['data']['object']['payment_status'])
|
|
||||||
logging.debug(event['data']['object']['metadata'])
|
|
||||||
|
|
||||||
session = stripe.checkout.Session.retrieve(
|
|
||||||
event['data']['object']['id'],
|
|
||||||
expand=['line_items'],
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.debug(session.line_items['data'][0]['quantity'])
|
|
||||||
|
|
||||||
# Fulfill the purchase...
|
|
||||||
provider_checkout_service = ProviderCheckoutService()
|
|
||||||
|
|
||||||
try:
|
|
||||||
provider_checkout_service.fulfill_provider_order(event, session.line_items)
|
|
||||||
except Exception as e:
|
|
||||||
|
|
||||||
logging.debug(str(e))
|
|
||||||
return 'success', 200
|
|
||||||
|
|
||||||
return 'success', 200
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(StripeWebhookApi, '/webhook/stripe')
|
|
@ -9,8 +9,8 @@ from controllers.console.setup import setup_required
|
|||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from core.model_providers.error import LLMBadRequestError
|
from core.model_providers.error import LLMBadRequestError
|
||||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||||
from services.provider_checkout_service import ProviderCheckoutService
|
|
||||||
from services.provider_service import ProviderService
|
from services.provider_service import ProviderService
|
||||||
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderListApi(Resource):
|
class ModelProviderListApi(Resource):
|
||||||
@ -264,16 +264,13 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider_name: str):
|
def get(self, provider_name: str):
|
||||||
provider_service = ProviderCheckoutService()
|
if provider_name != 'anthropic':
|
||||||
provider_checkout = provider_service.create_checkout(
|
raise ValueError(f'provider name {provider_name} is invalid')
|
||||||
tenant_id=current_user.current_tenant_id,
|
|
||||||
provider_name=provider_name,
|
|
||||||
account=current_user
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
data = BillingService.get_model_provider_payment_link(provider_name=provider_name,
|
||||||
'url': provider_checkout.get_checkout_url()
|
tenant_id=current_user.current_tenant_id,
|
||||||
}
|
account_id=current_user.id)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderFreeQuotaSubmitApi(Resource):
|
class ModelProviderFreeQuotaSubmitApi(Resource):
|
||||||
|
@ -191,23 +191,6 @@ class AnthropicProvider(BaseModelProvider):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_payment_info(self) -> Optional[dict]:
|
|
||||||
"""
|
|
||||||
get product info if it payable.
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if hosted_model_providers.anthropic \
|
|
||||||
and hosted_model_providers.anthropic.paid_enabled:
|
|
||||||
return {
|
|
||||||
'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
|
|
||||||
'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
|
|
||||||
'min_quantity': hosted_model_providers.anthropic.paid_min_quantity,
|
|
||||||
'max_quantity': hosted_model_providers.anthropic.paid_max_quantity,
|
|
||||||
}
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||||
"""
|
"""
|
||||||
|
@ -267,14 +267,6 @@ class BaseModelProvider(BaseModel, ABC):
|
|||||||
).update({'last_used': datetime.utcnow()})
|
).update({'last_used': datetime.utcnow()})
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def get_payment_info(self) -> Optional[dict]:
|
|
||||||
"""
|
|
||||||
get product info if it payable.
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
|
def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
|
||||||
"""
|
"""
|
||||||
get provider model.
|
get provider model.
|
||||||
|
@ -13,8 +13,6 @@ class HostedOpenAI(BaseModel):
|
|||||||
quota_limit: int = 0
|
quota_limit: int = 0
|
||||||
"""Quota limit for the openai hosted model. -1 means unlimited."""
|
"""Quota limit for the openai hosted model. -1 means unlimited."""
|
||||||
paid_enabled: bool = False
|
paid_enabled: bool = False
|
||||||
paid_stripe_price_id: str = None
|
|
||||||
paid_increase_quota: int = 1
|
|
||||||
|
|
||||||
|
|
||||||
class HostedAzureOpenAI(BaseModel):
|
class HostedAzureOpenAI(BaseModel):
|
||||||
@ -30,10 +28,6 @@ class HostedAnthropic(BaseModel):
|
|||||||
quota_limit: int = 0
|
quota_limit: int = 0
|
||||||
"""Quota limit for the anthropic hosted model. -1 means unlimited."""
|
"""Quota limit for the anthropic hosted model. -1 means unlimited."""
|
||||||
paid_enabled: bool = False
|
paid_enabled: bool = False
|
||||||
paid_stripe_price_id: str = None
|
|
||||||
paid_increase_quota: int = 1000000
|
|
||||||
paid_min_quantity: int = 20
|
|
||||||
paid_max_quantity: int = 100
|
|
||||||
|
|
||||||
|
|
||||||
class HostedModelProviders(BaseModel):
|
class HostedModelProviders(BaseModel):
|
||||||
@ -68,8 +62,6 @@ def init_app(app: Flask):
|
|||||||
api_key=app.config.get("HOSTED_OPENAI_API_KEY"),
|
api_key=app.config.get("HOSTED_OPENAI_API_KEY"),
|
||||||
quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"),
|
quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"),
|
||||||
paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"),
|
paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"),
|
||||||
paid_stripe_price_id=app.config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
|
|
||||||
paid_increase_quota=app.config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"):
|
if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"):
|
||||||
@ -85,10 +77,6 @@ def init_app(app: Flask):
|
|||||||
api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"),
|
api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"),
|
||||||
quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"),
|
quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"),
|
||||||
paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
|
paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
|
||||||
paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
|
|
||||||
paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
|
|
||||||
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
|
|
||||||
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
|
if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
|
||||||
|
@ -282,21 +282,6 @@ class OpenAIProvider(BaseModelProvider):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_payment_info(self) -> Optional[dict]:
|
|
||||||
"""
|
|
||||||
get payment info if it payable.
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if hosted_model_providers.openai \
|
|
||||||
and hosted_model_providers.openai.paid_enabled:
|
|
||||||
return {
|
|
||||||
'product_id': hosted_model_providers.openai.paid_stripe_price_id,
|
|
||||||
'increase_quota': hosted_model_providers.openai.paid_increase_quota,
|
|
||||||
}
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||||
"""
|
"""
|
||||||
|
@ -1,6 +0,0 @@
|
|||||||
import stripe
|
|
||||||
|
|
||||||
|
|
||||||
def init_app(app):
|
|
||||||
if app.config.get('STRIPE_API_KEY'):
|
|
||||||
stripe.api_key = app.config.get('STRIPE_API_KEY')
|
|
@ -135,21 +135,6 @@ class TenantPreferredModelProvider(db.Model):
|
|||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||||
|
|
||||||
|
|
||||||
class ProviderOrderPaymentStatus(Enum):
|
|
||||||
WAIT_PAY = 'wait_pay'
|
|
||||||
PAID = 'paid'
|
|
||||||
PAY_FAILED = 'pay_failed'
|
|
||||||
REFUNDED = 'refunded'
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def value_of(value):
|
|
||||||
for member in ProviderOrderPaymentStatus:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderOrder(db.Model):
|
class ProviderOrder(db.Model):
|
||||||
__tablename__ = 'provider_orders'
|
__tablename__ = 'provider_orders'
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
|
@ -46,7 +46,6 @@ websocket-client~=1.6.1
|
|||||||
dashscope~=1.11.0
|
dashscope~=1.11.0
|
||||||
huggingface_hub~=0.16.4
|
huggingface_hub~=0.16.4
|
||||||
transformers~=4.31.0
|
transformers~=4.31.0
|
||||||
stripe~=5.5.0
|
|
||||||
pandas==1.5.3
|
pandas==1.5.3
|
||||||
xinference-client~=0.6.4
|
xinference-client~=0.6.4
|
||||||
safetensors==0.3.2
|
safetensors==0.3.2
|
||||||
|
@ -10,7 +10,7 @@ class BillingService:
|
|||||||
def get_info(cls, tenant_id: str):
|
def get_info(cls, tenant_id: str):
|
||||||
params = {'tenant_id': tenant_id}
|
params = {'tenant_id': tenant_id}
|
||||||
|
|
||||||
billing_info = cls._send_request('GET', '/info', params=params)
|
billing_info = cls._send_request('GET', '/subscription/info', params=params)
|
||||||
|
|
||||||
return billing_info
|
return billing_info
|
||||||
|
|
||||||
@ -18,16 +18,26 @@ class BillingService:
|
|||||||
def get_subscription(cls, plan: str,
|
def get_subscription(cls, plan: str,
|
||||||
interval: str,
|
interval: str,
|
||||||
prefilled_email: str = '',
|
prefilled_email: str = '',
|
||||||
user_name: str = '',
|
|
||||||
tenant_id: str = ''):
|
tenant_id: str = ''):
|
||||||
params = {
|
params = {
|
||||||
'plan': plan,
|
'plan': plan,
|
||||||
'interval': interval,
|
'interval': interval,
|
||||||
'prefilled_email': prefilled_email,
|
'prefilled_email': prefilled_email,
|
||||||
'user_name': user_name,
|
|
||||||
'tenant_id': tenant_id
|
'tenant_id': tenant_id
|
||||||
}
|
}
|
||||||
return cls._send_request('GET', '/subscription', params=params)
|
return cls._send_request('GET', '/subscription/payment-link', params=params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_model_provider_payment_link(cls,
|
||||||
|
provider_name: str,
|
||||||
|
tenant_id: str,
|
||||||
|
account_id: str):
|
||||||
|
params = {
|
||||||
|
'provider_name': provider_name,
|
||||||
|
'tenant_id': tenant_id,
|
||||||
|
'account_id': account_id
|
||||||
|
}
|
||||||
|
return cls._send_request('GET', '/model-provider/payment-link', params=params)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_invoices(cls, prefilled_email: str = ''):
|
def get_invoices(cls, prefilled_email: str = ''):
|
||||||
@ -45,10 +55,3 @@ class BillingService:
|
|||||||
response = requests.request(method, url, json=json, params=params, headers=headers)
|
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||||
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def process_event(cls, event: dict):
|
|
||||||
json = {
|
|
||||||
"content": event,
|
|
||||||
}
|
|
||||||
return cls._send_request('POST', '/webhook/stripe', json=json)
|
|
||||||
|
@ -1,174 +0,0 @@
|
|||||||
import datetime
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import stripe
|
|
||||||
from flask import current_app
|
|
||||||
|
|
||||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.account import Account
|
|
||||||
from models.provider import ProviderOrder, ProviderOrderPaymentStatus, ProviderType, Provider, ProviderQuotaType
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderCheckout:
|
|
||||||
def __init__(self, stripe_checkout_session):
|
|
||||||
self.stripe_checkout_session = stripe_checkout_session
|
|
||||||
|
|
||||||
def get_checkout_url(self):
|
|
||||||
return self.stripe_checkout_session.url
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderCheckoutService:
|
|
||||||
def create_checkout(self, tenant_id: str, provider_name: str, account: Account) -> ProviderCheckout:
|
|
||||||
# check provider name is valid
|
|
||||||
model_provider_rules = ModelProviderFactory.get_provider_rules()
|
|
||||||
if provider_name not in model_provider_rules:
|
|
||||||
raise ValueError(f'provider name {provider_name} is invalid')
|
|
||||||
|
|
||||||
model_provider_rule = model_provider_rules[provider_name]
|
|
||||||
|
|
||||||
# check provider name can be paid
|
|
||||||
self._check_provider_payable(provider_name, model_provider_rule)
|
|
||||||
|
|
||||||
# get stripe checkout product id
|
|
||||||
paid_provider = self._get_paid_provider(tenant_id, provider_name)
|
|
||||||
model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
|
|
||||||
model_provider = model_provider_class(provider=paid_provider)
|
|
||||||
payment_info = model_provider.get_payment_info()
|
|
||||||
if not payment_info:
|
|
||||||
raise ValueError(f'provider name {provider_name} not support payment')
|
|
||||||
|
|
||||||
payment_product_id = payment_info['product_id']
|
|
||||||
payment_min_quantity = payment_info['min_quantity']
|
|
||||||
payment_max_quantity = payment_info['max_quantity']
|
|
||||||
|
|
||||||
# create provider order
|
|
||||||
provider_order = ProviderOrder(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
provider_name=provider_name,
|
|
||||||
account_id=account.id,
|
|
||||||
payment_product_id=payment_product_id,
|
|
||||||
quantity=1,
|
|
||||||
payment_status=ProviderOrderPaymentStatus.WAIT_PAY.value
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.add(provider_order)
|
|
||||||
db.session.flush()
|
|
||||||
|
|
||||||
line_item = {
|
|
||||||
'price': f'{payment_product_id}',
|
|
||||||
'quantity': payment_min_quantity
|
|
||||||
}
|
|
||||||
|
|
||||||
if payment_min_quantity > 1 and payment_max_quantity != payment_min_quantity:
|
|
||||||
line_item['adjustable_quantity'] = {
|
|
||||||
'enabled': True,
|
|
||||||
'minimum': payment_min_quantity,
|
|
||||||
'maximum': payment_max_quantity
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# create stripe checkout session
|
|
||||||
checkout_session = stripe.checkout.Session.create(
|
|
||||||
line_items=[
|
|
||||||
line_item
|
|
||||||
],
|
|
||||||
mode='payment',
|
|
||||||
success_url=current_app.config.get("CONSOLE_WEB_URL")
|
|
||||||
+ f'?provider_name={provider_name}&payment_result=succeeded',
|
|
||||||
cancel_url=current_app.config.get("CONSOLE_WEB_URL")
|
|
||||||
+ f'?provider_name={provider_name}&payment_result=cancelled',
|
|
||||||
automatic_tax={'enabled': True},
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception(e)
|
|
||||||
raise ValueError(f'provider name {provider_name} create checkout session failed, please try again later')
|
|
||||||
|
|
||||||
provider_order.payment_id = checkout_session.id
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
return ProviderCheckout(checkout_session)
|
|
||||||
|
|
||||||
def fulfill_provider_order(self, event, line_items):
|
|
||||||
provider_order = db.session.query(ProviderOrder) \
|
|
||||||
.filter(ProviderOrder.payment_id == event['data']['object']['id']) \
|
|
||||||
.first()
|
|
||||||
|
|
||||||
if not provider_order:
|
|
||||||
raise ValueError(f'provider order not found, payment id: {event["data"]["object"]["id"]}')
|
|
||||||
|
|
||||||
if provider_order.payment_status != ProviderOrderPaymentStatus.WAIT_PAY.value:
|
|
||||||
raise ValueError(
|
|
||||||
f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}')
|
|
||||||
|
|
||||||
provider_order.transaction_id = event['data']['object']['payment_intent']
|
|
||||||
provider_order.currency = event['data']['object']['currency']
|
|
||||||
provider_order.total_amount = event['data']['object']['amount_subtotal']
|
|
||||||
provider_order.payment_status = ProviderOrderPaymentStatus.PAID.value
|
|
||||||
provider_order.paid_at = datetime.datetime.utcnow()
|
|
||||||
provider_order.updated_at = provider_order.paid_at
|
|
||||||
|
|
||||||
# update provider quota
|
|
||||||
provider = db.session.query(Provider).filter(
|
|
||||||
Provider.tenant_id == provider_order.tenant_id,
|
|
||||||
Provider.provider_name == provider_order.provider_name,
|
|
||||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
|
||||||
Provider.quota_type == ProviderQuotaType.PAID.value
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not provider:
|
|
||||||
raise ValueError(f'provider not found, tenant id: {provider_order.tenant_id}, '
|
|
||||||
f'provider name: {provider_order.provider_name}')
|
|
||||||
|
|
||||||
model_provider_class = ModelProviderFactory.get_model_provider_class(provider_order.provider_name)
|
|
||||||
model_provider = model_provider_class(provider=provider)
|
|
||||||
payment_info = model_provider.get_payment_info()
|
|
||||||
|
|
||||||
quantity = line_items['data'][0]['quantity']
|
|
||||||
|
|
||||||
if not payment_info:
|
|
||||||
increase_quota = 0
|
|
||||||
else:
|
|
||||||
increase_quota = int(payment_info['increase_quota']) * quantity
|
|
||||||
|
|
||||||
if increase_quota > 0:
|
|
||||||
provider.quota_limit += increase_quota
|
|
||||||
provider.is_valid = True
|
|
||||||
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
def _check_provider_payable(self, provider_name: str, model_provider_rule: dict):
|
|
||||||
if ProviderType.SYSTEM.value not in model_provider_rule['support_provider_types']:
|
|
||||||
raise ValueError(f'provider name {provider_name} not support payment')
|
|
||||||
|
|
||||||
if 'system_config' not in model_provider_rule:
|
|
||||||
raise ValueError(f'provider name {provider_name} not support payment')
|
|
||||||
|
|
||||||
if 'supported_quota_types' not in model_provider_rule['system_config']:
|
|
||||||
raise ValueError(f'provider name {provider_name} not support payment')
|
|
||||||
|
|
||||||
if 'paid' not in model_provider_rule['system_config']['supported_quota_types']:
|
|
||||||
raise ValueError(f'provider name {provider_name} not support payment')
|
|
||||||
|
|
||||||
def _get_paid_provider(self, tenant_id: str, provider_name: str):
|
|
||||||
paid_provider = db.session.query(Provider) \
|
|
||||||
.filter(
|
|
||||||
Provider.tenant_id == tenant_id,
|
|
||||||
Provider.provider_name == provider_name,
|
|
||||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
|
||||||
Provider.quota_type == ProviderQuotaType.PAID.value,
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not paid_provider:
|
|
||||||
paid_provider = Provider(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
provider_name=provider_name,
|
|
||||||
provider_type=ProviderType.SYSTEM.value,
|
|
||||||
quota_type=ProviderQuotaType.PAID.value,
|
|
||||||
quota_limit=0,
|
|
||||||
quota_used=0,
|
|
||||||
)
|
|
||||||
db.session.add(paid_provider)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
return paid_provider
|
|
Loading…
x
Reference in New Issue
Block a user