mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 03:45:55 +08:00
feat: add spark v3.0 llm support (#1434)
This commit is contained in:
parent
518083dfe0
commit
076f3289d2
@ -28,14 +28,19 @@ class SparkProvider(BaseModelProvider):
|
|||||||
if model_type == ModelType.TEXT_GENERATION:
|
if model_type == ModelType.TEXT_GENERATION:
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
'id': 'spark',
|
'id': 'spark-v3',
|
||||||
'name': 'Spark V1.5',
|
'name': 'Spark V3.0',
|
||||||
'mode': ModelMode.CHAT.value,
|
'mode': ModelMode.CHAT.value,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'id': 'spark-v2',
|
'id': 'spark-v2',
|
||||||
'name': 'Spark V2.0',
|
'name': 'Spark V2.0',
|
||||||
'mode': ModelMode.CHAT.value,
|
'mode': ModelMode.CHAT.value,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'id': 'spark',
|
||||||
|
'name': 'Spark V1.5',
|
||||||
|
'mode': ModelMode.CHAT.value,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
@ -96,7 +101,7 @@ class SparkProvider(BaseModelProvider):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
chat_llm = ChatSpark(
|
chat_llm = ChatSpark(
|
||||||
model_name='spark-v2',
|
model_name='spark-v3',
|
||||||
max_tokens=10,
|
max_tokens=10,
|
||||||
temperature=0.01,
|
temperature=0.01,
|
||||||
**credential_kwargs
|
**credential_kwargs
|
||||||
@ -110,10 +115,10 @@ class SparkProvider(BaseModelProvider):
|
|||||||
|
|
||||||
chat_llm(messages)
|
chat_llm(messages)
|
||||||
except SparkError as ex:
|
except SparkError as ex:
|
||||||
# try spark v1.5 if v2.1 failed
|
# try spark v2.1 if v3.1 failed
|
||||||
try:
|
try:
|
||||||
chat_llm = ChatSpark(
|
chat_llm = ChatSpark(
|
||||||
model_name='spark',
|
model_name='spark-v2',
|
||||||
max_tokens=10,
|
max_tokens=10,
|
||||||
temperature=0.01,
|
temperature=0.01,
|
||||||
**credential_kwargs
|
**credential_kwargs
|
||||||
@ -127,10 +132,27 @@ class SparkProvider(BaseModelProvider):
|
|||||||
|
|
||||||
chat_llm(messages)
|
chat_llm(messages)
|
||||||
except SparkError as ex:
|
except SparkError as ex:
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
# try spark v1.5 if v2.1 failed
|
||||||
except Exception as ex:
|
try:
|
||||||
logging.exception('Spark config validation failed')
|
chat_llm = ChatSpark(
|
||||||
raise ex
|
model_name='spark',
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0.01,
|
||||||
|
**credential_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
HumanMessage(
|
||||||
|
content="ping"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
chat_llm(messages)
|
||||||
|
except SparkError as ex:
|
||||||
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
except Exception as ex:
|
||||||
|
logging.exception('Spark config validation failed')
|
||||||
|
raise ex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logging.exception('Spark config validation failed')
|
logging.exception('Spark config validation failed')
|
||||||
raise ex
|
raise ex
|
||||||
|
@ -22,6 +22,12 @@
|
|||||||
"completion": "0.36",
|
"completion": "0.36",
|
||||||
"unit": "0.0001",
|
"unit": "0.0001",
|
||||||
"currency": "RMB"
|
"currency": "RMB"
|
||||||
|
},
|
||||||
|
"spark-v3": {
|
||||||
|
"prompt": "0.36",
|
||||||
|
"completion": "0.36",
|
||||||
|
"unit": "0.0001",
|
||||||
|
"currency": "RMB"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
20
api/core/third_party/spark/spark_llm.py
vendored
20
api/core/third_party/spark/spark_llm.py
vendored
@ -19,9 +19,25 @@ class SparkLLMClient:
|
|||||||
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
|
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
|
||||||
|
|
||||||
domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
|
domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
|
||||||
api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1'
|
|
||||||
|
|
||||||
self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general'
|
model_api_configs = {
|
||||||
|
'spark': {
|
||||||
|
'version': 'v1.1',
|
||||||
|
'chat_domain': 'general'
|
||||||
|
},
|
||||||
|
'spark-v2': {
|
||||||
|
'version': 'v2.1',
|
||||||
|
'chat_domain': 'generalv2'
|
||||||
|
},
|
||||||
|
'spark-v3': {
|
||||||
|
'version': 'v3.1',
|
||||||
|
'chat_domain': 'generalv3'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
api_version = model_api_configs[model_name]['version']
|
||||||
|
|
||||||
|
self.chat_domain = model_api_configs[model_name]['chat_domain']
|
||||||
self.api_base = f"wss://{domain}/{api_version}/chat"
|
self.api_base = f"wss://{domain}/{api_version}/chat"
|
||||||
self.app_id = app_id
|
self.app_id = app_id
|
||||||
self.ws_url = self.create_url(
|
self.ws_url = self.create_url(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user