feat: add spark v2 support (#885)

This commit is contained in:
takatost 2023-08-17 15:08:57 +08:00 committed by GitHub
parent c4d759dfba
commit f42e7d1a61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 7 deletions

View File

@ -1,5 +1,4 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
@ -19,6 +18,7 @@ class SparkModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatSpark(
model_name=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,

View File

@ -29,7 +29,11 @@ class SparkProvider(BaseModelProvider):
return [
{
'id': 'spark',
'name': '星火认知大模型',
'name': 'Spark V1.5',
},
{
'id': 'spark-v2',
'name': 'Spark V2.0',
}
]
else:

View File

@ -25,6 +25,7 @@ class ChatSpark(BaseChatModel):
.. code-block:: python
client = SparkLLMClient(
model_name="<model_name>",
app_id="<app_id>",
api_key="<api_key>",
api_secret="<api_secret>"
@ -32,6 +33,9 @@ class ChatSpark(BaseChatModel):
"""
client: Any = None #: :meta private:
model_name: str = "spark"
"""The Spark model name."""
max_tokens: int = 256
"""Denotes the number of tokens to predict per generation."""
@ -66,6 +70,7 @@ class ChatSpark(BaseChatModel):
)
values["client"] = SparkLLMClient(
model_name=values["model_name"],
app_id=values["app_id"],
api_key=values["api_key"],
api_secret=values["api_secret"],

View File

@ -16,9 +16,13 @@ import websocket
class SparkLLMClient:
def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat')
domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1'
self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general'
self.api_base = f"wss://{domain}/{api_version}/chat"
self.app_id = app_id
self.ws_url = self.create_url(
urlparse(self.api_base).netloc,
@ -76,7 +80,10 @@ class SparkLLMClient:
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def on_error(self, ws, error):
self.queue.put({'error': error})
self.queue.put({
'status_code': error.status_code,
'error': error.resp_body.decode('utf-8')
})
ws.close()
def on_close(self, ws, close_status_code, close_reason):
@ -120,7 +127,7 @@ class SparkLLMClient:
},
"parameter": {
"chat": {
"domain": "general"
"domain": self.chat_domain
}
},
"payload": {
@ -139,7 +146,14 @@ class SparkLLMClient:
while True:
content = self.queue.get()
if 'error' in content:
raise SparkError(content['error'])
if content['status_code'] == 401:
raise SparkError('[Spark] The credentials you provided are incorrect. '
'Please double-check and fill them in again.')
elif content['status_code'] == 403:
raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
"Please try again after obtaining the necessary permissions.")
else:
raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
if 'data' not in content:
break

View File

@ -471,6 +471,7 @@ class ProviderService:
for model in model_list:
valid_model_dict = {
"model_name": model['id'],
"model_display_name": model['name'],
"model_type": model_type,
"model_provider": {
"provider_name": provider.provider_name,