From be431449bdbaaa8c4c26691f51f42951e32406be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E8=85=BE?= <101850389+hangters@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:56:42 +0800 Subject: [PATCH] add support for XunFei Spark (#2017) ### What problem does this PR solve? #1853 add support for XunFei Spark ### Type of change - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: Zhedong Cen --- api/apps/llm_app.py | 5 +- conf/llm_factories.json | 7 ++ rag/llm/__init__.py | 3 +- rag/llm/chat_model.py | 21 ++++- web/src/assets/svg/llm/spark.svg | 1 + web/src/locales/en.ts | 3 + web/src/locales/zh-traditional.ts | 3 + web/src/locales/zh.ts | 3 + .../user-setting/setting-model/constant.ts | 1 + .../pages/user-setting/setting-model/hooks.ts | 27 ++++++ .../user-setting/setting-model/index.tsx | 28 +++++- .../setting-model/spark-modal/index.tsx | 94 +++++++++++++++++++ 12 files changed, 190 insertions(+), 6 deletions(-) create mode 100644 web/src/assets/svg/llm/spark.svg create mode 100644 web/src/pages/user-setting/setting-model/spark-modal/index.tsx diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index e85a6679d..2e67934d3 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -138,6 +138,9 @@ def add_llm(): elif factory == "OpenAI-API-Compatible": llm_name = req["llm_name"]+"___OpenAI-API" api_key = req.get("api_key","xxxxxxxxxxxxxxx") + elif factory =="XunFei Spark": + llm_name = req["llm_name"] + api_key = req.get("spark_api_password","") else: llm_name = req["llm_name"] api_key = req.get("api_key","xxxxxxxxxxxxxxx") @@ -165,7 +168,7 @@ def add_llm(): msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e) elif llm["model_type"] == LLMType.CHAT.value: mdl = ChatModel[factory]( - key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None, + key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate","XunFei Spark"] else None, model_name=llm["llm_name"], base_url=llm["api_base"] ) diff --git a/conf/llm_factories.json b/conf/llm_factories.json index a6e422518..62afcad83 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -3194,6 +3194,13 @@ "model_type": "image2text" } ] + }, + { + "name": "XunFei Spark", + "logo": "", + "tags": "LLM", + "status": "1", + "llm": [] } ] } diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 6bf9f96ec..b0c1b7a56 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -100,7 +100,8 @@ ChatModel = { "SILICONFLOW": SILICONFLOWChat, "01.AI": YiChat, "Replicate": ReplicateChat, - "Tencent Hunyuan": HunyuanChat + "Tencent Hunyuan": HunyuanChat, + "XunFei Spark": SparkChat } diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 75832f7c8..b6d035859 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1133,12 +1133,12 @@ class HunyuanChat(Base): from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( TencentCloudSDKException, ) - + _gen_conf = {} _history = [{k.capitalize(): v for k, v in item.items() } for item in history] if system: _history.insert(0, {"Role": "system", "Content": system}) - + if "temperature" in gen_conf: _gen_conf["Temperature"] = gen_conf["temperature"] if "top_p" in gen_conf: @@ -1168,3 +1168,20 @@ class HunyuanChat(Base): yield ans + "\n**ERROR**: " + str(e) yield total_tokens + + +class SparkChat(Base): + def __init__( + self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1" + ): + if not base_url: + base_url = "https://spark-api-open.xf-yun.com/v1" + model2version = { + "Spark-Max": "generalv3.5", + "Spark-Lite": "general", + "Spark-Pro": "generalv3", + "Spark-Pro-128K": "pro-128k", + "Spark-4.0-Ultra": "4.0Ultra", + } + model_version = model2version[model_name] + super().__init__(key, model_version, base_url) diff --git a/web/src/assets/svg/llm/spark.svg b/web/src/assets/svg/llm/spark.svg new file mode 100644 index 000000000..30f6040f2 --- /dev/null +++ b/web/src/assets/svg/llm/spark.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 45ca820fe..d705b22e8 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -525,6 +525,9 @@ The above is the content you need to summarize.`, HunyuanSIDMessage: 'Please input your Secret ID', addHunyuanSK: 'Hunyuan Secret Key', HunyuanSKMessage: 'Please input your Secret Key', + SparkModelNameMessage: 'Please select Spark model', + addSparkAPIPassword: 'Spark APIPassword', + SparkAPIPasswordMessage: 'please input your APIPassword', }, message: { registered: 'Registered!', diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts index 237b5edea..5c6d9ccf5 100644 --- a/web/src/locales/zh-traditional.ts +++ b/web/src/locales/zh-traditional.ts @@ -488,6 +488,9 @@ export default { HunyuanSIDMessage: '請輸入 Secret ID', addHunyuanSK: '混元 Secret Key', HunyuanSKMessage: '請輸入 Secret Key', + SparkModelNameMessage: '請選擇星火模型!', + addSparkAPIPassword: '星火 APIPassword', + SparkAPIPasswordMessage: '請輸入 APIPassword', }, message: { registered: '註冊成功', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 265475cd5..61d52c6e1 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -505,6 +505,9 @@ export default { HunyuanSIDMessage: '请输入 Secret ID', addHunyuanSK: '混元 Secret Key', HunyuanSKMessage: '请输入 Secret Key', + SparkModelNameMessage: '请选择星火模型!', + addSparkAPIPassword: '星火 APIPassword', + SparkAPIPasswordMessage: '请输入 APIPassword', }, message: { registered: '注册成功', diff --git a/web/src/pages/user-setting/setting-model/constant.ts b/web/src/pages/user-setting/setting-model/constant.ts index 503634519..de44f45ae 100644 --- a/web/src/pages/user-setting/setting-model/constant.ts +++ b/web/src/pages/user-setting/setting-model/constant.ts @@ -33,6 +33,7 @@ export const IconMap = { '01.AI': 'yi', Replicate: 'replicate', 'Tencent Hunyuan': 'hunyuan', + 'XunFei Spark': 'spark', }; export const BedrockRegionList = [ diff --git a/web/src/pages/user-setting/setting-model/hooks.ts b/web/src/pages/user-setting/setting-model/hooks.ts index 96c96a5dd..dcf33bbd0 100644 --- a/web/src/pages/user-setting/setting-model/hooks.ts +++ b/web/src/pages/user-setting/setting-model/hooks.ts @@ -190,6 +190,33 @@ export const useSubmitHunyuan = () => { }; }; +export const useSubmitSpark = () => { + const { addLlm, loading } = useAddLlm(); + const { + visible: SparkAddingVisible, + hideModal: hideSparkAddingModal, + showModal: showSparkAddingModal, + } = useSetModalState(); + + const onSparkAddingOk = useCallback( + async (payload: IAddLlmRequestBody) => { + const ret = await addLlm(payload); + if (ret === 0) { + hideSparkAddingModal(); + } + }, + [hideSparkAddingModal, addLlm], + ); + + return { + SparkAddingLoading: loading, + onSparkAddingOk, + SparkAddingVisible, + hideSparkAddingModal, + showSparkAddingModal, + }; +}; + export const useSubmitBedrock = () => { const { addLlm, loading } = useAddLlm(); const { diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx index ecdc62ad9..7e0500df0 100644 --- a/web/src/pages/user-setting/setting-model/index.tsx +++ b/web/src/pages/user-setting/setting-model/index.tsx @@ -36,12 +36,14 @@ import { useSubmitBedrock, useSubmitHunyuan, useSubmitOllama, + useSubmitSpark, useSubmitSystemModelSetting, useSubmitVolcEngine, } from './hooks'; import HunyuanModal from './hunyuan-modal'; import styles from './index.less'; import OllamaModal from './ollama-modal'; +import SparkModal from './spark-modal'; import SystemModelSettingModal from './system-model-setting-modal'; import VolcEngineModal from './volcengine-modal'; @@ -92,7 +94,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => {