From 5decdde18255e96e937c03e2f816ac9ba8e686a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E8=85=BE?= <101850389+hangters@users.noreply.github.com> Date: Mon, 2 Sep 2024 12:06:41 +0800 Subject: [PATCH] add support for Google Cloud (#2175) ### What problem does this PR solve? #1853 add support for Google Cloud ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Zhedong Cen --- api/apps/llm_app.py | 8 + conf/llm_factories.json | 7 + rag/llm/__init__.py | 1 + rag/llm/chat_model.py | 161 +++++++++++++++++- requirements.txt | 1 + requirements_arm.txt | 1 + web/src/assets/svg/llm/google-cloud.svg | 1 + web/src/locales/en.ts | 10 ++ web/src/locales/zh-traditional.ts | 10 ++ web/src/locales/zh.ts | 10 ++ .../user-setting/setting-model/constant.ts | 1 + .../setting-model/google-modal/index.tsx | 95 +++++++++++ .../pages/user-setting/setting-model/hooks.ts | 27 +++ .../user-setting/setting-model/index.tsx | 22 ++- 14 files changed, 352 insertions(+), 3 deletions(-) create mode 100644 web/src/assets/svg/llm/google-cloud.svg create mode 100644 web/src/pages/user-setting/setting-model/google-modal/index.tsx diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 26cbd4576..c50c0a572 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -150,6 +150,14 @@ def add_llm(): llm_name = req["llm_name"] api_key = '{' + f'"fish_audio_ak": "{req.get("fish_audio_ak", "")}", ' \ f'"fish_audio_refid": "{req.get("fish_audio_refid", "59cb5986671546eaa6ca8ae6f29f6d22")}"' + '}' + elif factory == "Google Cloud": + llm_name = req["llm_name"] + api_key = ( + "{" + f'"google_project_id": "{req.get("google_project_id", "")}", ' + f'"google_region": "{req.get("google_region", "")}", ' + f'"google_service_account_key": "{req.get("google_service_account_key", "")}"' + + "}" + ) else: llm_name = req["llm_name"] api_key = req.get("api_key","xxxxxxxxxxxxxxx") diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 698443efd..891eccf52 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -3352,6 +3352,13 @@ "model_type": "rerank" } ] + }, + { + "name": "Google Cloud", + "logo": "", + "tags": "LLM", + "status": "1", + "llm": [] } ] } diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 54bfc67ff..c3b30b1c9 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -107,6 +107,7 @@ ChatModel = { "XunFei Spark": SparkChat, "BaiduYiyan": BaiduYiyanChat, "Anthropic": AnthropicChat, + "Google Cloud": GoogleChat, } diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 9569daa24..500b08b87 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -701,9 +701,13 @@ class GeminiChat(Base): self.model = GenerativeModel(model_name=self.model_name) self.model._client = _client + def chat(self,system,history,gen_conf): + from google.generativeai.types import content_types + if system: - history.insert(0, {"role": "user", "parts": system}) + self.model._system_instruction = content_types.to_content(system) + if 'max_tokens' in gen_conf: gen_conf['max_output_tokens'] = gen_conf['max_tokens'] for k in list(gen_conf.keys()): @@ -725,8 +729,10 @@ class GeminiChat(Base): return "**ERROR**: " + str(e), 0 def chat_streamly(self, system, history, gen_conf): + from google.generativeai.types import content_types + if system: - history.insert(0, {"role": "user", "parts": system}) + self.model._system_instruction = content_types.to_content(system) if 'max_tokens' in gen_conf: gen_conf['max_output_tokens'] = gen_conf['max_tokens'] for k in list(gen_conf.keys()): @@ -1257,3 +1263,154 @@ class AnthropicChat(Base): yield ans + "\n**ERROR**: " + str(e) yield total_tokens + + +class GoogleChat(Base): + def __init__(self, key, model_name, base_url=None): + from google.oauth2 import service_account + import base64 + + key = json.load(key) + access_token = json.loads( + base64.b64decode(key.get("google_service_account_key", "")) + ) + project_id = key.get("google_project_id", "") + region = key.get("google_region", "") + + scopes = ["https://www.googleapis.com/auth/cloud-platform"] + self.model_name = model_name + self.system = "" + + if "claude" in self.model_name: + from anthropic import AnthropicVertex + from google.auth.transport.requests import Request + + if access_token: + credits = service_account.Credentials.from_service_account_info( + access_token, scopes=scopes + ) + request = Request() + credits.refresh(request) + token = credits.token + self.client = AnthropicVertex( + region=region, project_id=project_id, access_token=token + ) + else: + self.client = AnthropicVertex(region=region, project_id=project_id) + else: + from google.cloud import aiplatform + import vertexai.generative_models as glm + + if access_token: + credits = service_account.Credentials.from_service_account_info( + access_token + ) + aiplatform.init( + credentials=credits, project=project_id, location=region + ) + else: + aiplatform.init(project=project_id, location=region) + self.client = glm.GenerativeModel(model_name=self.model_name) + + def chat(self, system, history, gen_conf): + if system: + self.system = system + + if "claude" in self.model_name: + if "max_tokens" not in gen_conf: + gen_conf["max_tokens"] = 4096 + try: + response = self.client.messages.create( + model=self.model_name, + messages=history, + system=self.system, + stream=False, + **gen_conf, + ).json() + ans = response["content"][0]["text"] + if response["stop_reason"] == "max_tokens": + ans += ( + "...\nFor the content length reason, it stopped, continue?" + if is_english([ans]) + else "······\n由于长度的原因,回答被截断了,要继续吗?" + ) + return ( + ans, + response["usage"]["input_tokens"] + + response["usage"]["output_tokens"], + ) + except Exception as e: + return ans + "\n**ERROR**: " + str(e), 0 + else: + self.client._system_instruction = self.system + if "max_tokens" in gen_conf: + gen_conf["max_output_tokens"] = gen_conf["max_tokens"] + for k in list(gen_conf.keys()): + if k not in ["temperature", "top_p", "max_output_tokens"]: + del gen_conf[k] + for item in history: + if "role" in item and item["role"] == "assistant": + item["role"] = "model" + if "content" in item: + item["parts"] = item.pop("content") + try: + response = self.client.generate_content( + history, generation_config=gen_conf + ) + ans = response.text + return ans, response.usage_metadata.total_token_count + except Exception as e: + return "**ERROR**: " + str(e), 0 + + def chat_streamly(self, system, history, gen_conf): + if system: + self.system = system + + if "claude" in self.model_name: + if "max_tokens" not in gen_conf: + gen_conf["max_tokens"] = 4096 + ans = "" + total_tokens = 0 + try: + response = self.client.messages.create( + model=self.model_name, + messages=history, + system=self.system, + stream=True, + **gen_conf, + ) + for res in response.iter_lines(): + res = res.decode("utf-8") + if "content_block_delta" in res and "data" in res: + text = json.loads(res[6:])["delta"]["text"] + ans += text + total_tokens += num_tokens_from_string(text) + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield total_tokens + else: + self.client._system_instruction = self.system + if "max_tokens" in gen_conf: + gen_conf["max_output_tokens"] = gen_conf["max_tokens"] + for k in list(gen_conf.keys()): + if k not in ["temperature", "top_p", "max_output_tokens"]: + del gen_conf[k] + for item in history: + if "role" in item and item["role"] == "assistant": + item["role"] = "model" + if "content" in item: + item["parts"] = item.pop("content") + ans = "" + try: + response = self.model.generate_content( + history, generation_config=gen_conf, stream=True + ) + for resp in response: + ans += resp.text + yield ans + + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield response._chunks[-1].usage_metadata.total_token_count diff --git a/requirements.txt b/requirements.txt index 0f2136606..2d0892441 100644 --- a/requirements.txt +++ b/requirements.txt @@ -85,6 +85,7 @@ tiktoken==0.6.0 torch==2.3.0 transformers==4.38.1 umap==0.1.1 +vertexai==1.64.0 volcengine==1.0.146 voyageai==0.2.3 webdriver_manager==4.0.1 diff --git a/requirements_arm.txt b/requirements_arm.txt index 01b9ce0e3..cf0b4c504 100644 --- a/requirements_arm.txt +++ b/requirements_arm.txt @@ -167,3 +167,4 @@ scholarly==1.7.11 deepl==1.18.0 psycopg2-binary==2.9.9 tabulate==0.9.0 +vertexai==1.64.0 \ No newline at end of file diff --git a/web/src/assets/svg/llm/google-cloud.svg b/web/src/assets/svg/llm/google-cloud.svg new file mode 100644 index 000000000..2f7870552 --- /dev/null +++ b/web/src/assets/svg/llm/google-cloud.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index b6563203d..ad49c53a3 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -499,6 +499,7 @@ The above is the content you need to summarize.`, upgrade: 'Upgrade', addLlmTitle: 'Add LLM', modelName: 'Model name', + modelID: 'Model ID', modelUid: 'Model UID', modelNameMessage: 'Please input your model name!', modelType: 'Model type', @@ -551,6 +552,15 @@ The above is the content you need to summarize.`, addFishAudioRefID: 'FishAudio Refrence ID', addFishAudioRefIDMessage: 'Please input the Reference ID (leave blank to use the default model).', + GoogleModelIDMessage: 'Please input your model ID!', + addGoogleProjectID: 'Project ID', + GoogleProjectIDMessage: 'Please input your Project ID', + addGoogleServiceAccountKey: + 'Service Account Key(Leave blank if you use Application Default Credentials)', + GoogleServiceAccountKeyMessage: + 'Please input Google Cloud Service Account Key in base64 format', + addGoogleRegion: 'Google Cloud Region', + GoogleRegionMessage: 'Please input Google Cloud Region', }, message: { registered: 'Registered!', diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts index 01eebf424..dc793d905 100644 --- a/web/src/locales/zh-traditional.ts +++ b/web/src/locales/zh-traditional.ts @@ -461,6 +461,7 @@ export default { upgrade: '升級', addLlmTitle: '添加Llm', modelName: '模型名稱', + modelID: '模型ID', modelUid: '模型uid', modelType: '模型類型', addLlmBaseUrl: '基礎 Url', @@ -511,6 +512,15 @@ export default { addFishAudioAKMessage: '請輸入 API KEY', addFishAudioRefID: 'FishAudio Refrence ID', addFishAudioRefIDMessage: '請輸入引用模型的ID(留空表示使用默認模型)', + GoogleModelIDMessage: '請輸入 model ID!', + addGoogleProjectID: 'Project ID', + GoogleProjectIDMessage: '請輸入 Project ID', + addGoogleServiceAccountKey: + 'Service Account Key(Leave blank if you use Application Default Credentials)', + GoogleServiceAccountKeyMessage: + '請輸入 Google Cloud Service Account Key in base64 format', + addGoogleRegion: 'Google Cloud 區域', + GoogleRegionMessage: '請輸入 Google Cloud 區域', }, message: { registered: '註冊成功', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index c82623082..c7b6ea2ce 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -478,6 +478,7 @@ export default { upgrade: '升级', addLlmTitle: '添加 LLM', modelName: '模型名称', + modelID: '模型ID', modelUid: '模型UID', modelType: '模型类型', addLlmBaseUrl: '基础 Url', @@ -528,6 +529,15 @@ export default { FishAudioAKMessage: '请输入 API KEY', addFishAudioRefID: 'FishAudio Refrence ID', FishAudioRefIDMessage: '请输入引用模型的ID(留空表示使用默认模型)', + GoogleModelIDMessage: '请输入 model ID!', + addGoogleProjectID: 'Project ID', + GoogleProjectIDMessage: '请输入 Project ID', + addGoogleServiceAccountKey: + 'Service Account Key(Leave blank if you use Application Default Credentials)', + GoogleServiceAccountKeyMessage: + '请输入 Google Cloud Service Account Key in base64 format', + addGoogleRegion: 'Google Cloud 区域', + GoogleRegionMessage: '请输入 Google Cloud 区域', }, message: { registered: '注册成功', diff --git a/web/src/pages/user-setting/setting-model/constant.ts b/web/src/pages/user-setting/setting-model/constant.ts index a644c21c8..8ac93e782 100644 --- a/web/src/pages/user-setting/setting-model/constant.ts +++ b/web/src/pages/user-setting/setting-model/constant.ts @@ -39,6 +39,7 @@ export const IconMap = { 'Tencent Cloud': 'tencent-cloud', Anthropic: 'anthropic', 'Voyage AI': 'voyage', + 'Google Cloud': 'google-cloud', }; export const BedrockRegionList = [ diff --git a/web/src/pages/user-setting/setting-model/google-modal/index.tsx b/web/src/pages/user-setting/setting-model/google-modal/index.tsx new file mode 100644 index 000000000..d1d17c581 --- /dev/null +++ b/web/src/pages/user-setting/setting-model/google-modal/index.tsx @@ -0,0 +1,95 @@ +import { useTranslate } from '@/hooks/common-hooks'; +import { IModalProps } from '@/interfaces/common'; +import { IAddLlmRequestBody } from '@/interfaces/request/llm'; +import { Form, Input, Modal, Select } from 'antd'; + +type FieldType = IAddLlmRequestBody & { + google_project_id: string; + google_region: string; + google_service_account_key: string; +}; + +const { Option } = Select; + +const GoogleModal = ({ + visible, + hideModal, + onOk, + loading, + llmFactory, +}: IModalProps & { llmFactory: string }) => { + const [form] = Form.useForm(); + + const { t } = useTranslate('setting'); + const handleOk = async () => { + const values = await form.validateFields(); + + const data = { + ...values, + llm_factory: llmFactory, + }; + + onOk?.(data); + }; + + return ( + +
+ + label={t('modelType')} + name="model_type" + initialValue={'chat'} + rules={[{ required: true, message: t('modelTypeMessage') }]} + > + + + + label={t('modelID')} + name="llm_name" + rules={[{ required: true, message: t('GoogleModelIDMessage') }]} + > + + + + label={t('addGoogleProjectID')} + name="google_project_id" + rules={[{ required: true, message: t('GoogleProjectIDMessage') }]} + > + + + + label={t('addGoogleRegion')} + name="google_region" + rules={[{ required: true, message: t('GoogleRegionMessage') }]} + > + + + + label={t('addGoogleServiceAccountKey')} + name="google_service_account_key" + rules={[ + { required: true, message: t('GoogleServiceAccountKeyMessage') }, + ]} + > + + + +
+ ); +}; + +export default GoogleModal; diff --git a/web/src/pages/user-setting/setting-model/hooks.ts b/web/src/pages/user-setting/setting-model/hooks.ts index 1c60bbcef..f076a27d1 100644 --- a/web/src/pages/user-setting/setting-model/hooks.ts +++ b/web/src/pages/user-setting/setting-model/hooks.ts @@ -298,6 +298,33 @@ export const useSubmitFishAudio = () => { }; }; +export const useSubmitGoogle = () => { + const { addLlm, loading } = useAddLlm(); + const { + visible: GoogleAddingVisible, + hideModal: hideGoogleAddingModal, + showModal: showGoogleAddingModal, + } = useSetModalState(); + + const onGoogleAddingOk = useCallback( + async (payload: IAddLlmRequestBody) => { + const ret = await addLlm(payload); + if (ret === 0) { + hideGoogleAddingModal(); + } + }, + [hideGoogleAddingModal, addLlm], + ); + + return { + GoogleAddingLoading: loading, + onGoogleAddingOk, + GoogleAddingVisible, + hideGoogleAddingModal, + showGoogleAddingModal, + }; +}; + export const useSubmitBedrock = () => { const { addLlm, loading } = useAddLlm(); const { diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx index 235ee66e3..c4a2a1cef 100644 --- a/web/src/pages/user-setting/setting-model/index.tsx +++ b/web/src/pages/user-setting/setting-model/index.tsx @@ -32,11 +32,13 @@ import ApiKeyModal from './api-key-modal'; import BedrockModal from './bedrock-modal'; import { IconMap } from './constant'; import FishAudioModal from './fish-audio-modal'; +import GoogleModal from './google-modal'; import { useHandleDeleteLlm, useSubmitApiKey, useSubmitBedrock, useSubmitFishAudio, + useSubmitGoogle, useSubmitHunyuan, useSubmitOllama, useSubmitSpark, @@ -104,7 +106,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => { item.name === 'XunFei Spark' || item.name === 'BaiduYiyan' || item.name === 'Fish Audio' || - item.name === 'Tencent Cloud' + item.name === 'Tencent Cloud' || + item.name === 'Google Cloud' ? t('addTheModel') : 'API-Key'} @@ -186,6 +189,14 @@ const UserSettingModel = () => { HunyuanAddingLoading, } = useSubmitHunyuan(); + const { + GoogleAddingVisible, + hideGoogleAddingModal, + showGoogleAddingModal, + onGoogleAddingOk, + GoogleAddingLoading, + } = useSubmitGoogle(); + const { TencentCloudAddingVisible, hideTencentCloudAddingModal, @@ -235,6 +246,7 @@ const UserSettingModel = () => { BaiduYiyan: showyiyanAddingModal, 'Fish Audio': showFishAudioAddingModal, 'Tencent Cloud': showTencentCloudAddingModal, + 'Google Cloud': showGoogleAddingModal, }), [ showBedrockAddingModal, @@ -244,6 +256,7 @@ const UserSettingModel = () => { showSparkAddingModal, showyiyanAddingModal, showFishAudioAddingModal, + showGoogleAddingModal, ], ); @@ -364,6 +377,13 @@ const UserSettingModel = () => { loading={HunyuanAddingLoading} llmFactory={'Tencent Hunyuan'} > +