diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index e85a6679d..2e67934d3 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -138,6 +138,9 @@ def add_llm():
elif factory == "OpenAI-API-Compatible":
llm_name = req["llm_name"]+"___OpenAI-API"
api_key = req.get("api_key","xxxxxxxxxxxxxxx")
+ elif factory =="XunFei Spark":
+ llm_name = req["llm_name"]
+ api_key = req.get("spark_api_password","")
else:
llm_name = req["llm_name"]
api_key = req.get("api_key","xxxxxxxxxxxxxxx")
@@ -165,7 +168,7 @@ def add_llm():
msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
elif llm["model_type"] == LLMType.CHAT.value:
mdl = ChatModel[factory](
- key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
+ key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate","XunFei Spark"] else None,
model_name=llm["llm_name"],
base_url=llm["api_base"]
)
diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index a6e422518..62afcad83 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -3194,6 +3194,13 @@
"model_type": "image2text"
}
]
+ },
+ {
+ "name": "XunFei Spark",
+ "logo": "",
+ "tags": "LLM",
+ "status": "1",
+ "llm": []
}
]
}
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 6bf9f96ec..b0c1b7a56 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -100,7 +100,8 @@ ChatModel = {
"SILICONFLOW": SILICONFLOWChat,
"01.AI": YiChat,
"Replicate": ReplicateChat,
- "Tencent Hunyuan": HunyuanChat
+ "Tencent Hunyuan": HunyuanChat,
+ "XunFei Spark": SparkChat
}
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 75832f7c8..b6d035859 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -1133,12 +1133,12 @@ class HunyuanChat(Base):
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
-
+
_gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items() } for item in history]
if system:
_history.insert(0, {"Role": "system", "Content": system})
-
+
if "temperature" in gen_conf:
_gen_conf["Temperature"] = gen_conf["temperature"]
if "top_p" in gen_conf:
@@ -1168,3 +1168,20 @@ class HunyuanChat(Base):
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
+
+
+class SparkChat(Base):
+ def __init__(
+ self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
+ ):
+ if not base_url:
+ base_url = "https://spark-api-open.xf-yun.com/v1"
+ model2version = {
+ "Spark-Max": "generalv3.5",
+ "Spark-Lite": "general",
+ "Spark-Pro": "generalv3",
+ "Spark-Pro-128K": "pro-128k",
+ "Spark-4.0-Ultra": "4.0Ultra",
+ }
+ model_version = model2version[model_name]
+ super().__init__(key, model_version, base_url)
diff --git a/web/src/assets/svg/llm/spark.svg b/web/src/assets/svg/llm/spark.svg
new file mode 100644
index 000000000..30f6040f2
--- /dev/null
+++ b/web/src/assets/svg/llm/spark.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts
index 45ca820fe..d705b22e8 100644
--- a/web/src/locales/en.ts
+++ b/web/src/locales/en.ts
@@ -525,6 +525,9 @@ The above is the content you need to summarize.`,
HunyuanSIDMessage: 'Please input your Secret ID',
addHunyuanSK: 'Hunyuan Secret Key',
HunyuanSKMessage: 'Please input your Secret Key',
+ SparkModelNameMessage: 'Please select Spark model',
+ addSparkAPIPassword: 'Spark APIPassword',
+ SparkAPIPasswordMessage: 'please input your APIPassword',
},
message: {
registered: 'Registered!',
diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts
index 237b5edea..5c6d9ccf5 100644
--- a/web/src/locales/zh-traditional.ts
+++ b/web/src/locales/zh-traditional.ts
@@ -488,6 +488,9 @@ export default {
HunyuanSIDMessage: '請輸入 Secret ID',
addHunyuanSK: '混元 Secret Key',
HunyuanSKMessage: '請輸入 Secret Key',
+ SparkModelNameMessage: '請選擇星火模型!',
+ addSparkAPIPassword: '星火 APIPassword',
+ SparkAPIPasswordMessage: '請輸入 APIPassword',
},
message: {
registered: '註冊成功',
diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts
index 265475cd5..61d52c6e1 100644
--- a/web/src/locales/zh.ts
+++ b/web/src/locales/zh.ts
@@ -505,6 +505,9 @@ export default {
HunyuanSIDMessage: '请输入 Secret ID',
addHunyuanSK: '混元 Secret Key',
HunyuanSKMessage: '请输入 Secret Key',
+ SparkModelNameMessage: '请选择星火模型!',
+ addSparkAPIPassword: '星火 APIPassword',
+ SparkAPIPasswordMessage: '请输入 APIPassword',
},
message: {
registered: '注册成功',
diff --git a/web/src/pages/user-setting/setting-model/constant.ts b/web/src/pages/user-setting/setting-model/constant.ts
index 503634519..de44f45ae 100644
--- a/web/src/pages/user-setting/setting-model/constant.ts
+++ b/web/src/pages/user-setting/setting-model/constant.ts
@@ -33,6 +33,7 @@ export const IconMap = {
'01.AI': 'yi',
Replicate: 'replicate',
'Tencent Hunyuan': 'hunyuan',
+ 'XunFei Spark': 'spark',
};
export const BedrockRegionList = [
diff --git a/web/src/pages/user-setting/setting-model/hooks.ts b/web/src/pages/user-setting/setting-model/hooks.ts
index 96c96a5dd..dcf33bbd0 100644
--- a/web/src/pages/user-setting/setting-model/hooks.ts
+++ b/web/src/pages/user-setting/setting-model/hooks.ts
@@ -190,6 +190,33 @@ export const useSubmitHunyuan = () => {
};
};
+export const useSubmitSpark = () => {
+ const { addLlm, loading } = useAddLlm();
+ const {
+ visible: SparkAddingVisible,
+ hideModal: hideSparkAddingModal,
+ showModal: showSparkAddingModal,
+ } = useSetModalState();
+
+ const onSparkAddingOk = useCallback(
+ async (payload: IAddLlmRequestBody) => {
+ const ret = await addLlm(payload);
+ if (ret === 0) {
+ hideSparkAddingModal();
+ }
+ },
+ [hideSparkAddingModal, addLlm],
+ );
+
+ return {
+ SparkAddingLoading: loading,
+ onSparkAddingOk,
+ SparkAddingVisible,
+ hideSparkAddingModal,
+ showSparkAddingModal,
+ };
+};
+
export const useSubmitBedrock = () => {
const { addLlm, loading } = useAddLlm();
const {
diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx
index ecdc62ad9..7e0500df0 100644
--- a/web/src/pages/user-setting/setting-model/index.tsx
+++ b/web/src/pages/user-setting/setting-model/index.tsx
@@ -36,12 +36,14 @@ import {
useSubmitBedrock,
useSubmitHunyuan,
useSubmitOllama,
+ useSubmitSpark,
useSubmitSystemModelSetting,
useSubmitVolcEngine,
} from './hooks';
import HunyuanModal from './hunyuan-modal';
import styles from './index.less';
import OllamaModal from './ollama-modal';
+import SparkModal from './spark-modal';
import SystemModelSettingModal from './system-model-setting-modal';
import VolcEngineModal from './volcengine-modal';
@@ -92,7 +94,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => {