Feat/add free provider apply (#829)

This commit is contained in:
takatost 2023-08-14 12:44:35 +08:00 committed by GitHub
parent 42a417167f
commit cc52cdc2a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 54 additions and 3 deletions

View File

@ -270,6 +270,20 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
} }
class ModelProviderFreeQuotaSubmitApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
provider_service = ProviderService()
result = provider_service.free_quota_submit(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name
)
return result
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers') api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate') api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>') api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
@ -283,3 +297,5 @@ api.add_resource(ModelProviderModelParameterRuleApi,
'/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules') '/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
api.add_resource(ModelProviderPaymentCheckoutUrlApi, api.add_resource(ModelProviderPaymentCheckoutUrlApi,
'/workspaces/current/model-providers/<string:provider_name>/checkout-url') '/workspaces/current/model-providers/<string:provider_name>/checkout-url')
api.add_resource(ModelProviderFreeQuotaSubmitApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')

View File

@ -3,7 +3,6 @@ import logging
from json import JSONDecodeError from json import JSONDecodeError
from typing import Type from typing import Type
from flask import current_app
from langchain.schema import HumanMessage from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter

View File

@ -50,6 +50,7 @@ class ChatSpark(BaseChatModel):
app_id: Optional[str] = None app_id: Optional[str] = None
api_key: Optional[str] = None api_key: Optional[str] = None
api_secret: Optional[str] = None api_secret: Optional[str] = None
api_domain: Optional[str] = None
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
@ -68,6 +69,7 @@ class ChatSpark(BaseChatModel):
app_id=values["app_id"], app_id=values["app_id"],
api_key=values["api_key"], api_key=values["api_key"],
api_secret=values["api_secret"], api_secret=values["api_secret"],
api_domain=values.get('api_domain')
) )
return values return values

View File

@ -16,9 +16,9 @@ import websocket
class SparkLLMClient: class SparkLLMClient:
def __init__(self, app_id: str, api_key: str, api_secret: str): def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
self.api_base = "ws://spark-api.xf-yun.com/v1.1/chat" self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat')
self.app_id = app_id self.app_id = app_id
self.ws_url = self.create_url( self.ws_url = self.create_url(
urlparse(self.api_base).netloc, urlparse(self.api_base).netloc,

View File

@ -1,8 +1,12 @@
import datetime import datetime
import json import json
import logging
import os
from collections import defaultdict from collections import defaultdict
from typing import Optional from typing import Optional
import requests
from core.model_providers.model_factory import ModelFactory from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db from extensions.ext_database import db
from core.model_providers.model_provider_factory import ModelProviderFactory from core.model_providers.model_provider_factory import ModelProviderFactory
@ -509,3 +513,33 @@ class ProviderService:
# get model parameter rules # get model parameter rules
return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type)) return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type))
def free_quota_submit(self, tenant_id: str, provider_name: str):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_url = os.environ.get("FREE_QUOTA_APPLY_URL")
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {api_key}"
}
response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider_name})
if not response.ok:
logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
raise ValueError(f"Error: {response.status_code} ")
if response.json()["code"] != 'success':
raise ValueError(
f"error: {response.json()['message']}"
)
rst = response.json()
if rst['type'] == 'redirect':
return {
'type': rst['type'],
'redirect_url': rst['redirect_url']
}
else:
return {
'type': rst['type'],
'result': 'success'
}