From eb51ad73d6ef87ccc73d80e24f97a6ad2a004ba0 Mon Sep 17 00:00:00 2001
From: yungongzi
Date: Thu, 23 May 2024 11:15:29 +0800
Subject: [PATCH] Add support for VolcEngine - the current version supports SDK2 (#885)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- The main idea is to assemble **ak**, **sk**, and **ep_id** into a dictionary and store it in the database **api_key** field.
- I am not very familiar with the front end, so the UI changes are modeled on the existing Ollama integration and may contain some redundancy.

### Configuration method

- Model name
    - Format requirement: {"VolcEngine model name":"endpoint_id"}
    - Example: {"Skylark-pro-32K":"ep-xxxxxxxxx"}
- Volcano ACCESS_KEY
    - Format requirement: the VOLC_ACCESSKEY of the VolcEngine account that hosts the model
- Volcano SECRET_KEY
    - Format requirement: the VOLC_SECRETKEY of the VolcEngine account that hosts the model

### What problem does this PR solve?

Adds VolcEngine as a supported LLM provider (the current version targets its SDK V2). VolcEngine authenticates with an access key / secret key pair plus a per-model endpoint ID rather than a single API key, so these three values are packed into the existing **api_key** field and unpacked again where the model is instantiated.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
---
 api/apps/llm_app.py                           |  27 +++-
 api/db/init_data.py                           |  22 +++-
 rag/llm/chat_model.py                         |  69 ++++++++++
 web/src/assets/svg/llm/volc_engine.svg        |  14 +++
 web/src/locales/en.ts                         |   5 +
 web/src/locales/zh-traditional.ts             |   7 +-
 web/src/locales/zh.ts                         |   5 +
 .../pages/user-setting/setting-model/hooks.ts |  35 ++++++
 .../user-setting/setting-model/index.tsx      |  21 ++++
 .../setting-model/volcengine-model/index.tsx  | 118 ++++++++++++++++++
 10 files changed, 315 insertions(+), 8 deletions(-)
 create mode 100644 web/src/assets/svg/llm/volc_engine.svg
 create mode 100644 web/src/pages/user-setting/setting-model/volcengine-model/index.tsx

diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 2e878af9e..b95601706 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -96,16 +96,29 @@ def set_api_key():
 @validate_request("llm_factory", "llm_name", "model_type")
 def add_llm():
     req = request.json
+    factory = req["llm_factory"]
+    # VolcEngine uses a special authentication scheme, so volc_ak,
+    # volc_sk and endpoint_id are assembled into the api_key field.
+    if factory == "VolcEngine":
+        temp = list(eval(req["llm_name"]).items())[0]
+        llm_name = temp[0]
+        endpoint_id = temp[1]
+        api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \
+                        f'"volc_sk": "{req.get("volc_sk", "")}", ' \
+                        f'"ep_id": "{endpoint_id}", ' + '}'
+    else:
+        llm_name = req["llm_name"]
+        api_key = "xxxxxxxxxxxxxxx"
+
     llm = {
         "tenant_id": current_user.id,
-        "llm_factory": req["llm_factory"],
+        "llm_factory": factory,
         "model_type": req["model_type"],
-        "llm_name": req["llm_name"],
+        "llm_name": llm_name,
         "api_base": req.get("api_base", ""),
-        "api_key": "xxxxxxxxxxxxxxx"
+        "api_key": api_key
     }
-    factory = req["llm_factory"]
     msg = ""
     if llm["model_type"] == LLMType.EMBEDDING.value:
         mdl = EmbeddingModel[factory](
@@ -118,7 +131,10 @@ def add_llm():
             msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
     elif llm["model_type"] == LLMType.CHAT.value:
         mdl = ChatModel[factory](
-            key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
+            key=llm['api_key'] if factory == "VolcEngine" else None,
+            model_name=llm["llm_name"],
+            base_url=llm["api_base"]
+        )
         try:
             m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
                 "temperature": 0.9})
@@ -134,7 +150,6 @@ def add_llm():
 
     if msg:
         return get_data_error_result(retmsg=msg)
-
     if not TenantLLMService.filter_update(
             [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
         TenantLLMService.save(**llm)
diff --git a/api/db/init_data.py b/api/db/init_data.py
index 1bf6f01a6..3044ea976 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -132,7 +132,12 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM",
     "status": "1",
-},
+},{
+    "name": "VolcEngine",
+    "logo": "",
+    "tags": "LLM, TEXT EMBEDDING",
+    "status": "1",
+}
 # {
 #     "name": "文心一言",
 #     "logo": "",
@@ -372,6 +377,21 @@ def init_llm_factory():
             "max_tokens": 16385,
             "model_type": LLMType.CHAT.value
         },
+        # ------------------------ VolcEngine -----------------------
+        {
+            "fid": factory_infos[9]["name"],
+            "llm_name": "Skylark2-pro-32k",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32768,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[9]["name"],
+            "llm_name": "Skylark2-pro-4k",
+            "tags": "LLM,CHAT,4k",
+            "max_tokens": 4096,
+            "model_type": LLMType.CHAT.value
+        },
     ]
     for info in factory_infos:
         try:
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 9a2eec5d5..e9eb470c2 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -19,6 +19,7 @@ from abc import ABC
 from openai import OpenAI
 import openai
 from ollama import Client
+from volcengine.maas.v2 import MaasService
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 
@@ -315,3 +316,71 @@ class LocalLLM(Base):
             yield answer + "\n**ERROR**: " + str(e)
         yield token_count
 
+
+
+class VolcEngineChat(Base):
+    def __init__(self, key, model_name, base_url):
+        """
+        To avoid modifying the original database fields (VolcEngine's authentication method is quite special),
+        ak, sk and ep_id are assembled into api_key, stored as a serialized dictionary, and parsed back out for use.
+        model_name is for display only.
+        """
+        self.client = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
+        self.volc_ak = eval(key).get('volc_ak', '')
+        self.volc_sk = eval(key).get('volc_sk', '')
+        self.client.set_ak(self.volc_ak)
+        self.client.set_sk(self.volc_sk)
+        self.model_name = eval(key).get('ep_id', '')
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        try:
+            req = {
+                "parameters": {
+                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
+                    "top_k": gen_conf.get("top_k", 0),
+                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
+                    "temperature": gen_conf.get("temperature", 0.1),
+                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
+                    "top_p": gen_conf.get("top_p", 0.3),
+                },
+                "messages": history
+            }
+            response = self.client.chat(self.model_name, req)
+            ans = response.choices[0].message.content.strip()
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.total_tokens
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        try:
+            req = {
+                "parameters": {
+                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
+                    "top_k": gen_conf.get("top_k", 0),
+                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
+                    "temperature": gen_conf.get("temperature", 0.1),
+                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
+                    "top_p": gen_conf.get("top_p", 0.3),
+                },
+                "messages": history
+            }
+            stream = self.client.stream_chat(self.model_name, req)
+            for resp in stream:
+                if not resp.choices[0].message.content:
+                    continue
+                ans += resp.choices[0].message.content
+                yield ans
+                if resp.choices[0].finish_reason == "stop":
+                    return resp.usage.total_tokens
+
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+        yield 0
diff --git a/web/src/assets/svg/llm/volc_engine.svg b/web/src/assets/svg/llm/volc_engine.svg
new file mode 100644
index 000000000..2c56cb00b
--- /dev/null
+++ b/web/src/assets/svg/llm/volc_engine.svg
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts
index c5affca35..0e345ea63 100644
--- a/web/src/locales/en.ts
+++ b/web/src/locales/en.ts
@@ -477,6 +477,11 @@ The above is the content you need to summarize.`,
     baseUrlNameMessage: 'Please input your base url!',
     vision: 'Does it support Vision?',
     ollamaLink: 'How to integrate {{name}}',
+    volcModelNameMessage: 'Please input your model name! Format: {"ModelName":"EndpointID"}',
+    addVolcEngineAK: 'VOLC ACCESS_KEY',
+    volcAKMessage: 'Please input your VOLC_ACCESS_KEY',
+    addVolcEngineSK: 'VOLC SECRET_KEY',
+    volcSKMessage: 'Please input your VOLC_SECRET_KEY',
   },
   message: {
     registered: 'Registered!',
diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts
index 58123c0d0..cff078267 100644
--- a/web/src/locales/zh-traditional.ts
+++ b/web/src/locales/zh-traditional.ts
@@ -440,7 +440,12 @@ export default {
     modelNameMessage: '請輸入模型名稱!',
     modelTypeMessage: '請輸入模型類型!',
     baseUrlNameMessage: '請輸入基礎 Url!',
-    ollamaLink: '如何集成Ollama',
+    ollamaLink: '如何集成 {{name}}',
+    volcModelNameMessage: '請輸入模型名稱!格式:{"模型名稱":"EndpointID"}',
+    addVolcEngineAK: '火山 ACCESS_KEY',
+    volcAKMessage: '請輸入VOLC_ACCESS_KEY',
+    addVolcEngineSK: '火山 SECRET_KEY',
+    volcSKMessage: '請輸入VOLC_SECRET_KEY',
   },
   message: {
     registered: '註冊成功',
diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts
index 9442898f0..effb3bf20 100644
--- a/web/src/locales/zh.ts
+++ b/web/src/locales/zh.ts
@@ -458,6 +458,11 @@ export default {
     modelTypeMessage: '请输入模型类型!',
     baseUrlNameMessage: '请输入基础 Url!',
     ollamaLink: '如何集成 {{name}}',
+    volcModelNameMessage: '请输入模型名称!格式:{"模型名称":"EndpointID"}',
+    addVolcEngineAK: '火山 ACCESS_KEY',
+    volcAKMessage: '请输入VOLC_ACCESS_KEY',
+    addVolcEngineSK: '火山 SECRET_KEY',
+    volcSKMessage: '请输入VOLC_SECRET_KEY',
   },
   message: {
     registered: '注册成功',
diff --git a/web/src/pages/user-setting/setting-model/hooks.ts b/web/src/pages/user-setting/setting-model/hooks.ts
index 6a5649225..dfdce7e25 100644
--- a/web/src/pages/user-setting/setting-model/hooks.ts
+++ b/web/src/pages/user-setting/setting-model/hooks.ts
@@ -166,6 +166,41 @@ export const useSubmitOllama = () => {
   };
 };
 
+export const useSubmitVolcEngine = () => {
+  const loading = useOneNamespaceEffectsLoading('settingModel', ['add_llm']);
+  const [selectedVolcFactory, setSelectedVolcFactory] = useState<string>('');
+  const addLlm = useAddLlm();
+  const {
+    visible: volcAddingVisible,
+    hideModal: hideVolcAddingModal,
+    showModal: showVolcAddingModal,
+  } = useSetModalState();
+
+  const onVolcAddingOk = useCallback(
+    async (payload: IAddLlmRequestBody) => {
+      const ret = await addLlm(payload);
+      if (ret === 0) {
+        hideVolcAddingModal();
+      }
+    },
+    [hideVolcAddingModal, addLlm],
+  );
+
+  const handleShowVolcAddingModal = (llmFactory: string) => {
+    setSelectedVolcFactory(llmFactory);
+    showVolcAddingModal();
+  };
+
+  return {
+    volcAddingLoading: loading,
+    onVolcAddingOk,
+    volcAddingVisible,
+    hideVolcAddingModal,
+    showVolcAddingModal: handleShowVolcAddingModal,
+    selectedVolcFactory,
+  };
+};
+
 export const useHandleDeleteLlm = (llmFactory: string) => {
   const deleteLlm = useDeleteLlm();
   const showDeleteConfirm = useShowDeleteConfirm();
diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx
index 3d39d55d7..69a770108 100644
--- a/web/src/pages/user-setting/setting-model/index.tsx
+++ b/web/src/pages/user-setting/setting-model/index.tsx
@@ -37,10 +37,12 @@ import {
   useSelectModelProvidersLoading,
   useSubmitApiKey,
   useSubmitOllama,
+  useSubmitVolcEngine,
   useSubmitSystemModelSetting,
 } from './hooks';
 import styles from './index.less';
 import OllamaModal from './ollama-modal';
+import VolcEngineModal from './volcengine-model';
 import SystemModelSettingModal from './system-model-setting-modal';
 
 const IconMap = {
@@ -52,6 +54,7 @@ const IconMap = {
   Ollama: 'ollama',
   Xinference: 'xinference',
   DeepSeek: 'deepseek',
+  VolcEngine: 'volc_engine',
 };
 
 const LlmIcon = ({ name }: { name: string }) => {
@@ -165,6 +168,15 @@ const UserSettingModel = () => {
     selectedLlmFactory,
   } = useSubmitOllama();
 
+  const {
+    volcAddingVisible,
+    hideVolcAddingModal,
+    showVolcAddingModal,
+    onVolcAddingOk,
+    volcAddingLoading,
+    selectedVolcFactory,
+  } = useSubmitVolcEngine();
+
   const handleApiKeyClick = useCallback(
     (llmFactory: string) => {
       if (isLocalLlmFactory(llmFactory)) {
@@ -179,6 +191,8 @@ const UserSettingModel = () => {
   const handleAddModel = (llmFactory: string) => () => {
     if (isLocalLlmFactory(llmFactory)) {
       showLlmAddingModal(llmFactory);
+    } else if (llmFactory === 'VolcEngine') {
+      showVolcAddingModal('VolcEngine');
     } else {
       handleApiKeyClick(llmFactory);
     }
@@ -270,6 +284,13 @@ const UserSettingModel = () => {
         loading={llmAddingLoading}
         llmFactory={selectedLlmFactory}
       ></OllamaModal>
+      <VolcEngineModal
+        visible={volcAddingVisible}
+        hideModal={hideVolcAddingModal}
+        onOk={onVolcAddingOk}
+        loading={volcAddingLoading}
+        llmFactory={selectedVolcFactory}
+      ></VolcEngineModal>
     </>
   );
 };
diff --git a/web/src/pages/user-setting/setting-model/volcengine-model/index.tsx b/web/src/pages/user-setting/setting-model/volcengine-model/index.tsx
new file mode 100644
index 000000000..65872067a
--- /dev/null
+++ b/web/src/pages/user-setting/setting-model/volcengine-model/index.tsx
@@ -0,0 +1,118 @@
+import { useTranslate } from '@/hooks/commonHooks';
+import { IModalProps } from '@/interfaces/common';
+import { IAddLlmRequestBody } from '@/interfaces/request/llm';
+import { Flex, Form, Input, Modal, Select, Space, Switch } from 'antd';
+import omit from 'lodash/omit';
+
+type FieldType = IAddLlmRequestBody & { vision: boolean };
+
+const { Option } = Select;
+
+const VolcEngineModal = ({
+  visible,
+  hideModal,
+  onOk,
+  loading,
+  llmFactory
+}: IModalProps<IAddLlmRequestBody> & { llmFactory: string }) => {
+  const [form] = Form.useForm<FieldType>();
+
+  const { t } = useTranslate('setting');
+
+  const handleOk = async () => {
+    const values = await form.validateFields();
+    const modelType =
+      values.model_type === 'chat' && values.vision
+        ? 'image2text'
+        : values.model_type;
+
+    const data = {
+      ...omit(values, ['vision']),
+      model_type: modelType,
+      llm_factory: llmFactory,
+    };
+    console.info(data);
+
+    onOk?.(data);
+  };
+
+  return (
+    <Modal
+      title={t('addLlmTitle', { name: llmFactory })}
+      open={visible}
+      onOk={handleOk}
+      onCancel={hideModal}
+      okButtonProps={{ loading }}
+      footer={(originNode) => {
+        return (
+          <Flex justify={'space-between'}>
+            <a target="_blank" rel="noreferrer">
+              {t('ollamaLink', { name: llmFactory })}
+            </a>
+            <Space>{originNode}</Space>
+          </Flex>
+        );
+      }}
+    >
+      <Form
+        name="basic"
+        style={{ maxWidth: 600 }}
+        autoComplete="off"
+        layout={'vertical'}
+        form={form}
+      >
+        <Form.Item<FieldType>
+          label={t('modelType')}
+          name="model_type"
+          initialValue={'chat'}
+          rules={[{ required: true, message: t('modelTypeMessage') }]}
+        >
+          <Select placeholder={t('modelTypeMessage')}>
+            <Option value="chat">chat</Option>
+            <Option value="embedding">embedding</Option>
+          </Select>
+        </Form.Item>
+        <Form.Item<FieldType>
+          label={t('modelName')}
+          name="llm_name"
+          rules={[{ required: true, message: t('volcModelNameMessage') }]}
+        >
+          <Input placeholder={t('volcModelNameMessage')} />
+        </Form.Item>
+        <Form.Item<FieldType>
+          label={t('addVolcEngineAK')}
+          name="volc_ak"
+          rules={[{ required: true, message: t('volcAKMessage') }]}
+        >
+          <Input placeholder={t('volcAKMessage')} />
+        </Form.Item>
+        <Form.Item<FieldType>
+          label={t('addVolcEngineSK')}
+          name="volc_sk"
+          rules={[{ required: true, message: t('volcSKMessage') }]}
+        >
+          <Input placeholder={t('volcSKMessage')} />
+        </Form.Item>
+        <Form.Item noStyle dependencies={['model_type']}>
+          {({ getFieldValue }) =>
+            getFieldValue('model_type') === 'chat' && (
+              <Form.Item label={t('vision')} valuePropName="checked" name={'vision'}>
+                <Switch />
+              </Form.Item>
+            )
+          }
+        </Form.Item>
+      </Form>
+    </Modal>
+ ); +}; + +export default VolcEngineModal;
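
### Credential round-trip (reference sketch)

The patch stores three VolcEngine credentials in the single `api_key` column by concatenating an f-string in `add_llm()` and reading it back with `eval()` in `VolcEngineChat.__init__()`. The sketch below shows the same round-trip with `json.dumps`/`json.loads` instead; it is illustrative only: the helper names are hypothetical, and `json.loads` would reject the patch's actual string because of its trailing comma. It demonstrates the intended data flow while avoiding `eval` on user-supplied input.

```python
import json


def pack_volc_api_key(volc_ak: str, volc_sk: str, endpoint_id: str) -> str:
    # What add_llm() builds with an f-string: one string holding all three
    # credentials, suitable for the existing api_key column.
    return json.dumps({"volc_ak": volc_ak, "volc_sk": volc_sk, "ep_id": endpoint_id})


def unpack_volc_api_key(key: str) -> tuple[str, str, str]:
    # What VolcEngineChat.__init__() recovers with eval(key); json.loads
    # accepts only real JSON, so malformed credentials fail loudly here.
    creds = json.loads(key)
    return creds.get("volc_ak", ""), creds.get("volc_sk", ""), creds.get("ep_id", "")


if __name__ == "__main__":
    # The UI sends llm_name as '{"display name": "endpoint id"}'; add_llm()
    # takes the first (and only) pair, as the configuration section describes.
    llm_name, endpoint_id = list(json.loads('{"Skylark2-pro-32k": "ep-xxxxxxxxx"}').items())[0]
    key = pack_volc_api_key("my-ak", "my-sk", endpoint_id)
    assert unpack_volc_api_key(key) == ("my-ak", "my-sk", "ep-xxxxxxxxx")
    print(llm_name, "->", unpack_volc_api_key(key))
```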