mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-06-30 08:15:12 +08:00
feat: add spark v2 support (#885)
This commit is contained in:
parent
c4d759dfba
commit
f42e7d1a61
@ -1,5 +1,4 @@
|
||||
import decimal
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
@ -19,6 +18,7 @@ class SparkModel(BaseLLM):
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return ChatSpark(
|
||||
model_name=self.name,
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
**self.credentials,
|
||||
|
@ -29,7 +29,11 @@ class SparkProvider(BaseModelProvider):
|
||||
return [
|
||||
{
|
||||
'id': 'spark',
|
||||
'name': '星火认知大模型',
|
||||
'name': 'Spark V1.5',
|
||||
},
|
||||
{
|
||||
'id': 'spark-v2',
|
||||
'name': 'Spark V2.0',
|
||||
}
|
||||
]
|
||||
else:
|
||||
|
5
api/core/third_party/langchain/llms/spark.py
vendored
5
api/core/third_party/langchain/llms/spark.py
vendored
@ -25,6 +25,7 @@ class ChatSpark(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
client = SparkLLMClient(
|
||||
model_name="<model_name>",
|
||||
app_id="<app_id>",
|
||||
api_key="<api_key>",
|
||||
api_secret="<api_secret>"
|
||||
@ -32,6 +33,9 @@ class ChatSpark(BaseChatModel):
|
||||
"""
|
||||
client: Any = None #: :meta private:
|
||||
|
||||
model_name: str = "spark"
|
||||
"""The Spark model name."""
|
||||
|
||||
max_tokens: int = 256
|
||||
"""Denotes the number of tokens to predict per generation."""
|
||||
|
||||
@ -66,6 +70,7 @@ class ChatSpark(BaseChatModel):
|
||||
)
|
||||
|
||||
values["client"] = SparkLLMClient(
|
||||
model_name=values["model_name"],
|
||||
app_id=values["app_id"],
|
||||
api_key=values["api_key"],
|
||||
api_secret=values["api_secret"],
|
||||
|
24
api/core/third_party/spark/spark_llm.py
vendored
24
api/core/third_party/spark/spark_llm.py
vendored
@ -16,9 +16,13 @@ import websocket
|
||||
|
||||
|
||||
class SparkLLMClient:
|
||||
def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
|
||||
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
|
||||
|
||||
self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat')
|
||||
domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
|
||||
api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1'
|
||||
|
||||
self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general'
|
||||
self.api_base = f"wss://{domain}/{api_version}/chat"
|
||||
self.app_id = app_id
|
||||
self.ws_url = self.create_url(
|
||||
urlparse(self.api_base).netloc,
|
||||
@ -76,7 +80,10 @@ class SparkLLMClient:
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
|
||||
def on_error(self, ws, error):
|
||||
self.queue.put({'error': error})
|
||||
self.queue.put({
|
||||
'status_code': error.status_code,
|
||||
'error': error.resp_body.decode('utf-8')
|
||||
})
|
||||
ws.close()
|
||||
|
||||
def on_close(self, ws, close_status_code, close_reason):
|
||||
@ -120,7 +127,7 @@ class SparkLLMClient:
|
||||
},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": "general"
|
||||
"domain": self.chat_domain
|
||||
}
|
||||
},
|
||||
"payload": {
|
||||
@ -139,7 +146,14 @@ class SparkLLMClient:
|
||||
while True:
|
||||
content = self.queue.get()
|
||||
if 'error' in content:
|
||||
raise SparkError(content['error'])
|
||||
if content['status_code'] == 401:
|
||||
raise SparkError('[Spark] The credentials you provided are incorrect. '
|
||||
'Please double-check and fill them in again.')
|
||||
elif content['status_code'] == 403:
|
||||
raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
|
||||
"Please try again after obtaining the necessary permissions.")
|
||||
else:
|
||||
raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
|
||||
|
||||
if 'data' not in content:
|
||||
break
|
||||
|
@ -471,6 +471,7 @@ class ProviderService:
|
||||
for model in model_list:
|
||||
valid_model_dict = {
|
||||
"model_name": model['id'],
|
||||
"model_display_name": model['name'],
|
||||
"model_type": model_type,
|
||||
"model_provider": {
|
||||
"provider_name": provider.provider_name,
|
||||
|
Loading…
x
Reference in New Issue
Block a user