mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-03 13:25:13 +08:00
feat: add spark v2 support (#885)
This commit is contained in:
parent
c4d759dfba
commit
f42e7d1a61
@ -1,5 +1,4 @@
|
|||||||
import decimal
|
import decimal
|
||||||
from functools import wraps
|
|
||||||
from typing import List, Optional, Any
|
from typing import List, Optional, Any
|
||||||
|
|
||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
@ -19,6 +18,7 @@ class SparkModel(BaseLLM):
|
|||||||
def _init_client(self) -> Any:
|
def _init_client(self) -> Any:
|
||||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||||
return ChatSpark(
|
return ChatSpark(
|
||||||
|
model_name=self.name,
|
||||||
streaming=self.streaming,
|
streaming=self.streaming,
|
||||||
callbacks=self.callbacks,
|
callbacks=self.callbacks,
|
||||||
**self.credentials,
|
**self.credentials,
|
||||||
|
@ -29,7 +29,11 @@ class SparkProvider(BaseModelProvider):
|
|||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
'id': 'spark',
|
'id': 'spark',
|
||||||
'name': '星火认知大模型',
|
'name': 'Spark V1.5',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'id': 'spark-v2',
|
||||||
|
'name': 'Spark V2.0',
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
|
5
api/core/third_party/langchain/llms/spark.py
vendored
5
api/core/third_party/langchain/llms/spark.py
vendored
@ -25,6 +25,7 @@ class ChatSpark(BaseChatModel):
|
|||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
client = SparkLLMClient(
|
client = SparkLLMClient(
|
||||||
|
model_name="<model_name>",
|
||||||
app_id="<app_id>",
|
app_id="<app_id>",
|
||||||
api_key="<api_key>",
|
api_key="<api_key>",
|
||||||
api_secret="<api_secret>"
|
api_secret="<api_secret>"
|
||||||
@ -32,6 +33,9 @@ class ChatSpark(BaseChatModel):
|
|||||||
"""
|
"""
|
||||||
client: Any = None #: :meta private:
|
client: Any = None #: :meta private:
|
||||||
|
|
||||||
|
model_name: str = "spark"
|
||||||
|
"""The Spark model name."""
|
||||||
|
|
||||||
max_tokens: int = 256
|
max_tokens: int = 256
|
||||||
"""Denotes the number of tokens to predict per generation."""
|
"""Denotes the number of tokens to predict per generation."""
|
||||||
|
|
||||||
@ -66,6 +70,7 @@ class ChatSpark(BaseChatModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
values["client"] = SparkLLMClient(
|
values["client"] = SparkLLMClient(
|
||||||
|
model_name=values["model_name"],
|
||||||
app_id=values["app_id"],
|
app_id=values["app_id"],
|
||||||
api_key=values["api_key"],
|
api_key=values["api_key"],
|
||||||
api_secret=values["api_secret"],
|
api_secret=values["api_secret"],
|
||||||
|
24
api/core/third_party/spark/spark_llm.py
vendored
24
api/core/third_party/spark/spark_llm.py
vendored
@ -16,9 +16,13 @@ import websocket
|
|||||||
|
|
||||||
|
|
||||||
class SparkLLMClient:
|
class SparkLLMClient:
|
||||||
def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
|
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
|
||||||
|
|
||||||
self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat')
|
domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
|
||||||
|
api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1'
|
||||||
|
|
||||||
|
self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general'
|
||||||
|
self.api_base = f"wss://{domain}/{api_version}/chat"
|
||||||
self.app_id = app_id
|
self.app_id = app_id
|
||||||
self.ws_url = self.create_url(
|
self.ws_url = self.create_url(
|
||||||
urlparse(self.api_base).netloc,
|
urlparse(self.api_base).netloc,
|
||||||
@ -76,7 +80,10 @@ class SparkLLMClient:
|
|||||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||||
|
|
||||||
def on_error(self, ws, error):
|
def on_error(self, ws, error):
|
||||||
self.queue.put({'error': error})
|
self.queue.put({
|
||||||
|
'status_code': error.status_code,
|
||||||
|
'error': error.resp_body.decode('utf-8')
|
||||||
|
})
|
||||||
ws.close()
|
ws.close()
|
||||||
|
|
||||||
def on_close(self, ws, close_status_code, close_reason):
|
def on_close(self, ws, close_status_code, close_reason):
|
||||||
@ -120,7 +127,7 @@ class SparkLLMClient:
|
|||||||
},
|
},
|
||||||
"parameter": {
|
"parameter": {
|
||||||
"chat": {
|
"chat": {
|
||||||
"domain": "general"
|
"domain": self.chat_domain
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"payload": {
|
"payload": {
|
||||||
@ -139,7 +146,14 @@ class SparkLLMClient:
|
|||||||
while True:
|
while True:
|
||||||
content = self.queue.get()
|
content = self.queue.get()
|
||||||
if 'error' in content:
|
if 'error' in content:
|
||||||
raise SparkError(content['error'])
|
if content['status_code'] == 401:
|
||||||
|
raise SparkError('[Spark] The credentials you provided are incorrect. '
|
||||||
|
'Please double-check and fill them in again.')
|
||||||
|
elif content['status_code'] == 403:
|
||||||
|
raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
|
||||||
|
"Please try again after obtaining the necessary permissions.")
|
||||||
|
else:
|
||||||
|
raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
|
||||||
|
|
||||||
if 'data' not in content:
|
if 'data' not in content:
|
||||||
break
|
break
|
||||||
|
@ -471,6 +471,7 @@ class ProviderService:
|
|||||||
for model in model_list:
|
for model in model_list:
|
||||||
valid_model_dict = {
|
valid_model_dict = {
|
||||||
"model_name": model['id'],
|
"model_name": model['id'],
|
||||||
|
"model_display_name": model['name'],
|
||||||
"model_type": model_type,
|
"model_type": model_type,
|
||||||
"model_provider": {
|
"model_provider": {
|
||||||
"provider_name": provider.provider_name,
|
"provider_name": provider.provider_name,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user