diff --git a/api/core/third_party/spark/spark_llm.py b/api/core/third_party/spark/spark_llm.py index 828fea75e4..b11e7c20b5 100644 --- a/api/core/third_party/spark/spark_llm.py +++ b/api/core/third_party/spark/spark_llm.py @@ -17,8 +17,12 @@ import websocket class SparkLLMClient: def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): - - domain = 'spark-api.xf-yun.com' if not api_domain else api_domain + domain = 'spark-api.xf-yun.com' + endpoint = 'chat' + if api_domain: + domain = api_domain + if model_name == 'spark-v3': + endpoint = 'multimodal' model_api_configs = { 'spark': { @@ -38,7 +42,7 @@ class SparkLLMClient: api_version = model_api_configs[model_name]['version'] self.chat_domain = model_api_configs[model_name]['chat_domain'] - self.api_base = f"wss://{domain}/{api_version}/chat" + self.api_base = f"wss://{domain}/{api_version}/{endpoint}" self.app_id = app_id self.ws_url = self.create_url( urlparse(self.api_base).netloc,