diff --git a/api/core/model_providers/providers/spark_provider.py b/api/core/model_providers/providers/spark_provider.py index 4174c01163..bd95db4890 100644 --- a/api/core/model_providers/providers/spark_provider.py +++ b/api/core/model_providers/providers/spark_provider.py @@ -28,14 +28,19 @@ class SparkProvider(BaseModelProvider): if model_type == ModelType.TEXT_GENERATION: return [ { - 'id': 'spark', - 'name': 'Spark V1.5', + 'id': 'spark-v3', + 'name': 'Spark V3.0', 'mode': ModelMode.CHAT.value, }, { 'id': 'spark-v2', 'name': 'Spark V2.0', 'mode': ModelMode.CHAT.value, + }, + { + 'id': 'spark', + 'name': 'Spark V1.5', + 'mode': ModelMode.CHAT.value, } ] else: @@ -96,7 +101,7 @@ class SparkProvider(BaseModelProvider): try: chat_llm = ChatSpark( - model_name='spark-v2', + model_name='spark-v3', max_tokens=10, temperature=0.01, **credential_kwargs @@ -110,10 +115,10 @@ class SparkProvider(BaseModelProvider): chat_llm(messages) except SparkError as ex: - # try spark v1.5 if v2.1 failed + # try spark v2.1 if v3.1 failed try: chat_llm = ChatSpark( - model_name='spark', + model_name='spark-v2', max_tokens=10, temperature=0.01, **credential_kwargs @@ -127,10 +132,27 @@ class SparkProvider(BaseModelProvider): chat_llm(messages) except SparkError as ex: - raise CredentialsValidateFailedError(str(ex)) - except Exception as ex: - logging.exception('Spark config validation failed') - raise ex + # try spark v1.5 if v2.1 failed + try: + chat_llm = ChatSpark( + model_name='spark', + max_tokens=10, + temperature=0.01, + **credential_kwargs + ) + + messages = [ + HumanMessage( + content="ping" + ) + ] + + chat_llm(messages) + except SparkError as ex: + raise CredentialsValidateFailedError(str(ex)) + except Exception as ex: + logging.exception('Spark config validation failed') + raise ex except Exception as ex: logging.exception('Spark config validation failed') raise ex diff --git a/api/core/model_providers/rules/spark.json b/api/core/model_providers/rules/spark.json index a3a01ae4a5..24133107f8 100644 --- a/api/core/model_providers/rules/spark.json +++ b/api/core/model_providers/rules/spark.json @@ -22,6 +22,12 @@ "completion": "0.36", "unit": "0.0001", "currency": "RMB" + }, + "spark-v3": { + "prompt": "0.36", + "completion": "0.36", + "unit": "0.0001", + "currency": "RMB" } } } \ No newline at end of file diff --git a/api/core/third_party/spark/spark_llm.py b/api/core/third_party/spark/spark_llm.py index 542565e732..828fea75e4 100644 --- a/api/core/third_party/spark/spark_llm.py +++ b/api/core/third_party/spark/spark_llm.py @@ -19,9 +19,25 @@ class SparkLLMClient: def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): domain = 'spark-api.xf-yun.com' if not api_domain else api_domain - api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1' - self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general' + model_api_configs = { + 'spark': { + 'version': 'v1.1', + 'chat_domain': 'general' + }, + 'spark-v2': { + 'version': 'v2.1', + 'chat_domain': 'generalv2' + }, + 'spark-v3': { + 'version': 'v3.1', + 'chat_domain': 'generalv3' + } + } + + api_version = model_api_configs[model_name]['version'] + + self.chat_domain = model_api_configs[model_name]['chat_domain'] self.api_base = f"wss://{domain}/{api_version}/chat" self.app_id = app_id self.ws_url = self.create_url(