diff --git a/api/core/model_providers/models/llm/spark_model.py b/api/core/model_providers/models/llm/spark_model.py index 5d8c97c463..a16318637c 100644 --- a/api/core/model_providers/models/llm/spark_model.py +++ b/api/core/model_providers/models/llm/spark_model.py @@ -1,5 +1,4 @@ import decimal -from functools import wraps from typing import List, Optional, Any from langchain.callbacks.manager import Callbacks @@ -19,6 +18,7 @@ class SparkModel(BaseLLM): def _init_client(self) -> Any: provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) return ChatSpark( + model_name=self.name, streaming=self.streaming, callbacks=self.callbacks, **self.credentials, diff --git a/api/core/model_providers/providers/spark_provider.py b/api/core/model_providers/providers/spark_provider.py index 4030a577dd..b55ea77f4d 100644 --- a/api/core/model_providers/providers/spark_provider.py +++ b/api/core/model_providers/providers/spark_provider.py @@ -29,7 +29,11 @@ class SparkProvider(BaseModelProvider): return [ { 'id': 'spark', - 'name': '星火认知大模型', + 'name': 'Spark V1.5', + }, + { + 'id': 'spark-v2', + 'name': 'Spark V2.0', } ] else: diff --git a/api/core/third_party/langchain/llms/spark.py b/api/core/third_party/langchain/llms/spark.py index 7bc777c484..56cc828fd8 100644 --- a/api/core/third_party/langchain/llms/spark.py +++ b/api/core/third_party/langchain/llms/spark.py @@ -25,6 +25,7 @@ class ChatSpark(BaseChatModel): .. code-block:: python client = SparkLLMClient( + model_name="", app_id="", api_key="", api_secret="" @@ -32,6 +33,9 @@ class ChatSpark(BaseChatModel): """ client: Any = None #: :meta private: + model_name: str = "spark" + """The Spark model name.""" + max_tokens: int = 256 """Denotes the number of tokens to predict per generation.""" @@ -66,6 +70,7 @@ class ChatSpark(BaseChatModel): ) values["client"] = SparkLLMClient( + model_name=values["model_name"], app_id=values["app_id"], api_key=values["api_key"], api_secret=values["api_secret"], diff --git a/api/core/third_party/spark/spark_llm.py b/api/core/third_party/spark/spark_llm.py index 1cc3b8a486..ae25a2b071 100644 --- a/api/core/third_party/spark/spark_llm.py +++ b/api/core/third_party/spark/spark_llm.py @@ -16,9 +16,13 @@ import websocket class SparkLLMClient: - def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): + def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): - self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat') + domain = 'spark-api.xf-yun.com' if not api_domain else api_domain + api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1' + + self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general' + self.api_base = f"wss://{domain}/{api_version}/chat" self.app_id = app_id self.ws_url = self.create_url( urlparse(self.api_base).netloc, @@ -76,7 +80,10 @@ class SparkLLMClient: ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) def on_error(self, ws, error): - self.queue.put({'error': error}) + self.queue.put({ + 'status_code': error.status_code, + 'error': error.resp_body.decode('utf-8') + }) ws.close() def on_close(self, ws, close_status_code, close_reason): @@ -120,7 +127,7 @@ class SparkLLMClient: }, "parameter": { "chat": { - "domain": "general" + "domain": self.chat_domain } }, "payload": { @@ -139,7 +146,14 @@ class SparkLLMClient: while True: content = self.queue.get() if 'error' in content: - raise SparkError(content['error']) + if content['status_code'] == 401: + raise SparkError('[Spark] The credentials you provided are incorrect. ' + 'Please double-check and fill them in again.') + elif content['status_code'] == 403: + raise SparkError("[Spark] Sorry, the credentials you provided are access denied. " + "Please try again after obtaining the necessary permissions.") + else: + raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}") if 'data' not in content: break diff --git a/api/services/provider_service.py b/api/services/provider_service.py index 8aba64153d..9a4ae5ee73 100644 --- a/api/services/provider_service.py +++ b/api/services/provider_service.py @@ -471,6 +471,7 @@ class ProviderService: for model in model_list: valid_model_dict = { "model_name": model['id'], + "model_display_name": model['name'], "model_type": model_type, "model_provider": { "provider_name": provider.provider_name,