diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index d32591803a..0fa01d28a3 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -270,6 +270,20 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): } +class ModelProviderFreeQuotaSubmitApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider_name: str): + provider_service = ProviderService() + result = provider_service.free_quota_submit( + tenant_id=current_user.current_tenant_id, + provider_name=provider_name + ) + + return result + + api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers') api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers//validate') api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/') @@ -283,3 +297,5 @@ api.add_resource(ModelProviderModelParameterRuleApi, '/workspaces/current/model-providers//models/parameter-rules') api.add_resource(ModelProviderPaymentCheckoutUrlApi, '/workspaces/current/model-providers//checkout-url') +api.add_resource(ModelProviderFreeQuotaSubmitApi, + '/workspaces/current/model-providers//free-quota-submit') diff --git a/api/core/model_providers/providers/spark_provider.py b/api/core/model_providers/providers/spark_provider.py index 7bcd060be2..4030a577dd 100644 --- a/api/core/model_providers/providers/spark_provider.py +++ b/api/core/model_providers/providers/spark_provider.py @@ -3,7 +3,6 @@ import logging from json import JSONDecodeError from typing import Type -from flask import current_app from langchain.schema import HumanMessage from core.helper import encrypter diff --git a/api/core/third_party/langchain/llms/spark.py b/api/core/third_party/langchain/llms/spark.py index 23eb7472a9..7bc777c484 100644 --- a/api/core/third_party/langchain/llms/spark.py +++ b/api/core/third_party/langchain/llms/spark.py @@ -50,6 +50,7 @@ class ChatSpark(BaseChatModel): app_id: Optional[str] = None api_key: Optional[str] = None api_secret: Optional[str] = None + api_domain: Optional[str] = None @root_validator() def validate_environment(cls, values: Dict) -> Dict: @@ -68,6 +69,7 @@ class ChatSpark(BaseChatModel): app_id=values["app_id"], api_key=values["api_key"], api_secret=values["api_secret"], + api_domain=values.get('api_domain') ) return values diff --git a/api/core/third_party/spark/spark_llm.py b/api/core/third_party/spark/spark_llm.py index 2b6d9b498c..1cc3b8a486 100644 --- a/api/core/third_party/spark/spark_llm.py +++ b/api/core/third_party/spark/spark_llm.py @@ -16,9 +16,9 @@ import websocket class SparkLLMClient: - def __init__(self, app_id: str, api_key: str, api_secret: str): + def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): - self.api_base = "ws://spark-api.xf-yun.com/v1.1/chat" + self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat') self.app_id = app_id self.ws_url = self.create_url( urlparse(self.api_base).netloc, diff --git a/api/services/provider_service.py b/api/services/provider_service.py index f061e68d92..8aba64153d 100644 --- a/api/services/provider_service.py +++ b/api/services/provider_service.py @@ -1,8 +1,12 @@ import datetime import json +import logging +import os from collections import defaultdict from typing import Optional +import requests + from core.model_providers.model_factory import ModelFactory from extensions.ext_database import db from core.model_providers.model_provider_factory import ModelProviderFactory @@ -509,3 +513,33 @@ class ProviderService: # get model parameter rules return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type)) + def free_quota_submit(self, tenant_id: str, provider_name: str): + api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") + api_url = os.environ.get("FREE_QUOTA_APPLY_URL") + + headers = { + 'Content-Type': 'application/json', + 'Authorization': f"Bearer {api_key}" + } + response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider_name}) + if not response.ok: + logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") + raise ValueError(f"Error: {response.status_code} ") + + if response.json()["code"] != 'success': + raise ValueError( + f"error: {response.json()['message']}" + ) + + rst = response.json() + + if rst['type'] == 'redirect': + return { + 'type': rst['type'], + 'redirect_url': rst['redirect_url'] + } + else: + return { + 'type': rst['type'], + 'result': 'success' + }