From b67484e77dade1dfad889cdf8e4deefea03df8ba Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=84=E8=85=BE?= <101850389+hangters@users.noreply.github.com>
Date: Tue, 6 Aug 2024 16:20:21 +0800
Subject: [PATCH] add support for OpenAI-API-Compatible LLM (#1787)

### What problem does this PR solve?

#1771 add support for OpenAI-API-Compatible

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen
---
 api/apps/llm_app.py                               |  9 ++++++---
 conf/llm_factories.json                           |  7 +++++++
 rag/llm/__init__.py                               | 12 ++++++++----
 rag/llm/chat_model.py                             | 14 ++++++++++++--
 rag/llm/cv_model.py                               | 11 +++++++++++
 rag/llm/embedding_model.py                        | 10 ++++++++++
 rag/llm/rerank_model.py                           |  8 ++++++++
 web/src/assets/svg/llm/openai-api.svg             |  1 +
 web/src/interfaces/request/llm.ts                 |  1 +
 web/src/pages/user-setting/constants.tsx          |  2 +-
 .../pages/user-setting/setting-model/constant.ts  |  3 ++-
 .../setting-model/ollama-modal/index.tsx          |  7 +++++++
 12 files changed, 74 insertions(+), 11 deletions(-)
 create mode 100644 web/src/assets/svg/llm/openai-api.svg

diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index d622b24b4..c4657c44c 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -129,6 +129,9 @@ def add_llm():
     elif factory == "LocalAI":
         llm_name = req["llm_name"]+"___LocalAI"
         api_key = "xxxxxxxxxxxxxxx"
+    elif factory == "OpenAI-API-Compatible":
+        llm_name = req["llm_name"]+"___OpenAI-API"
+        api_key = req["api_key"]
     else:
         llm_name = req["llm_name"]
         api_key = "xxxxxxxxxxxxxxx"
@@ -145,7 +148,7 @@ def add_llm():
     msg = ""
     if llm["model_type"] == LLMType.EMBEDDING.value:
         mdl = EmbeddingModel[factory](
-            key=llm['api_key'] if factory in ["VolcEngine", "Bedrock"] else None,
+            key=llm['api_key'] if factory in ["VolcEngine", "Bedrock", "OpenAI-API-Compatible"] else None,
             model_name=llm["llm_name"],
             base_url=llm["api_base"])
         try:
@@ -156,7 +159,7 @@ def add_llm():
             msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
     elif llm["model_type"] == LLMType.CHAT.value:
         mdl = ChatModel[factory](
-            key=llm['api_key'] if factory in ["VolcEngine", "Bedrock"] else None,
+            key=llm['api_key'] if factory in ["VolcEngine", "Bedrock", "OpenAI-API-Compatible"] else None,
             model_name=llm["llm_name"],
             base_url=llm["api_base"]
         )
@@ -181,7 +184,7 @@ def add_llm():
                 e)
     elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
         mdl = CvModel[factory](
-            key=None, model_name=llm["llm_name"], base_url=llm["api_base"]
+            key=llm["api_key"] if factory in ["OpenAI-API-Compatible"] else None, model_name=llm["llm_name"], base_url=llm["api_base"]
         )
         try:
             img_url = (
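Note on the `add_llm()` change above: the factory is encoded into the stored model name with a `___OpenAI-API` suffix, and the wrapper classes later in this patch strip it back off with `split("___")[0]` before calling the upstream API. A minimal sketch of that round trip, not part of the patch, with illustrative helper names:

```python
# Sketch of the name-suffix convention the diff relies on: add_llm()
# appends a factory marker so the provider can be recovered later, and
# the model wrappers strip it before talking to the actual endpoint.

FACTORY_SUFFIXES = {
    "LocalAI": "___LocalAI",
    "OpenAI-API-Compatible": "___OpenAI-API",
}

def tag_llm_name(llm_name: str, factory: str) -> str:
    """Append the factory marker, as add_llm() does above."""
    return llm_name + FACTORY_SUFFIXES.get(factory, "")

def untag_llm_name(stored_name: str) -> str:
    """Recover the bare model id, mirroring split("___")[0] in the wrappers."""
    return stored_name.split("___")[0]

assert untag_llm_name(tag_llm_name("llama3:8b", "OpenAI-API-Compatible")) == "llama3:8b"
```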
diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index 4b29802c7..3eb23c17e 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -158,6 +158,13 @@
             "status": "1",
             "llm": []
         },
+        {
+            "name": "OpenAI-API-Compatible",
+            "logo": "",
+            "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+            "status": "1",
+            "llm": []
+        },
         {
             "name": "Moonshot",
             "logo": "",
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 3ebb230e2..4c3182cae 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -36,7 +36,8 @@ EmbeddingModel = {
     "Bedrock": BedrockEmbed,
     "Gemini": GeminiEmbed,
     "NVIDIA": NvidiaEmbed,
-    "LM-Studio": LmStudioEmbed
+    "LM-Studio": LmStudioEmbed,
+    "OpenAI-API-Compatible": OpenAI_APIEmbed
 }
 
 
@@ -53,7 +54,8 @@ CvModel = {
     "LocalAI": LocalAICV,
     "NVIDIA": NvidiaCV,
     "LM-Studio": LmStudioCV,
-    "StepFun":StepFunCV
+    "StepFun":StepFunCV,
+    "OpenAI-API-Compatible": OpenAI_APICV
 }
 
 
@@ -78,7 +80,8 @@ ChatModel = {
     "OpenRouter": OpenRouterChat,
     "StepFun": StepFunChat,
     "NVIDIA": NvidiaChat,
-    "LM-Studio": LmStudioChat
+    "LM-Studio": LmStudioChat,
+    "OpenAI-API-Compatible": OpenAI_APIChat
 }
 
 
@@ -88,7 +91,8 @@ RerankModel = {
    "Youdao": YoudaoRerank,
    "Xinference": XInferenceRerank,
    "NVIDIA": NvidiaRerank,
-   "LM-Studio": LmStudioRerank
+   "LM-Studio": LmStudioRerank,
+   "OpenAI-API-Compatible": OpenAI_APIRerank
 }
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 303be41c6..40dbff111 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -887,6 +887,16 @@ class LmStudioChat(Base):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
         if base_url.split("/")[-1] != "v1":
-            self.base_url = os.path.join(base_url, "v1")
-        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
+            base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
         self.model_name = model_name
+
+
+class OpenAI_APIChat(Base):
+    def __init__(self, key, model_name, base_url):
+        if not base_url:
+            raise ValueError("url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
+        model_name = model_name.split("___")[0]
+        super().__init__(key, model_name, base_url)
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 1277791ee..82a773024 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -638,3 +638,14 @@ class LmStudioCV(GptV4):
         self.client = OpenAI(api_key="lm-studio", base_url=base_url)
         self.model_name = model_name
         self.lang = lang
+
+
+class OpenAI_APICV(GptV4):
+    def __init__(self, key, model_name, base_url, lang="Chinese"):
+        if not base_url:
+            raise ValueError("url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key=key, base_url=base_url)
+        self.model_name = model_name.split("___")[0]
+        self.lang = lang
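The new wrapper classes above all share the same base-URL normalization: append a `v1` segment when the last path component is not already `v1`. One caveat worth noting: `os.path.join` happens to work for URLs on POSIX but would produce backslashes on Windows. A string-based sketch of the same check, assuming nothing beyond the standard library (the helper name is illustrative):

```python
# Equivalent "/v1" normalization without os.path.join, so the result is
# a valid URL on every platform.

def ensure_v1(base_url: str) -> str:
    """Append the OpenAI-style /v1 path segment if it is missing."""
    if not base_url:
        raise ValueError("url cannot be None")
    base_url = base_url.rstrip("/")
    if base_url.split("/")[-1] != "v1":
        base_url = base_url + "/v1"
    return base_url

print(ensure_v1("http://localhost:8000"))     # http://localhost:8000/v1
print(ensure_v1("http://localhost:8000/v1"))  # http://localhost:8000/v1
```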
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index d9eb484c8..3c3a018d4 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -513,3 +513,13 @@ class LmStudioEmbed(LocalAIEmbed):
         self.base_url = os.path.join(base_url, "v1")
         self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
         self.model_name = model_name
+
+
+class OpenAI_APIEmbed(OpenAIEmbed):
+    def __init__(self, key, model_name, base_url):
+        if not base_url:
+            raise ValueError("url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key=key, base_url=base_url)
+        self.model_name = model_name.split("___")[0]
\ No newline at end of file
diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py
index b6d5421ee..f5e89437f 100644
--- a/rag/llm/rerank_model.py
+++ b/rag/llm/rerank_model.py
@@ -212,3 +212,11 @@ class LmStudioRerank(Base):
 
     def similarity(self, query: str, texts: list):
         raise NotImplementedError("The LmStudioRerank has not been implement")
+
+
+class OpenAI_APIRerank(Base):
+    def __init__(self, key, model_name, base_url):
+        pass
+
+    def similarity(self, query: str, texts: list):
+        raise NotImplementedError("The OpenAI_APIRerank has not been implemented")
diff --git a/web/src/assets/svg/llm/openai-api.svg b/web/src/assets/svg/llm/openai-api.svg
new file mode 100644
index 000000000..a0ecf992f
--- /dev/null
+++ b/web/src/assets/svg/llm/openai-api.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/web/src/interfaces/request/llm.ts b/web/src/interfaces/request/llm.ts
index 727af0149..2db66484e 100644
--- a/web/src/interfaces/request/llm.ts
+++ b/web/src/interfaces/request/llm.ts
@@ -3,6 +3,7 @@ export interface IAddLlmRequestBody {
   llm_name: string;
   model_type: string;
   api_base?: string; // chat|embedding|speech2text|image2text
+  api_key: string;
 }
 
 export interface IDeleteLlmRequestBody {
diff --git a/web/src/pages/user-setting/constants.tsx b/web/src/pages/user-setting/constants.tsx
index 97df5415e..cf2bce9ed 100644
--- a/web/src/pages/user-setting/constants.tsx
+++ b/web/src/pages/user-setting/constants.tsx
@@ -17,4 +17,4 @@ export const UserSettingIconMap = {
 
 export * from '@/constants/setting';
 
-export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio'];
+export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio','OpenAI-API-Compatible'];
diff --git a/web/src/pages/user-setting/setting-model/constant.ts b/web/src/pages/user-setting/setting-model/constant.ts
index 9493835d8..865eb29bf 100644
--- a/web/src/pages/user-setting/setting-model/constant.ts
+++ b/web/src/pages/user-setting/setting-model/constant.ts
@@ -21,7 +21,8 @@ export const IconMap = {
   LocalAI: 'local-ai',
   StepFun: 'stepfun',
   NVIDIA:'nvidia',
-  'LM-Studio':'lm-studio'
+  'LM-Studio':'lm-studio',
+  'OpenAI-API-Compatible':'openai-api'
 };
 
 export const BedrockRegionList = [
diff --git a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
index ffa71241e..d102d6572 100644
--- a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
@@ -92,6 +92,13 @@ const OllamaModal = ({
         >
+        <Form.Item
+          label={t('apiKey')}
+          name="api_key"
+          rules={[{ required: false, message: t('apiKeyMessage') }]}
+        >
+          <Input placeholder={t('apiKeyMessage')} />
+        </Form.Item>
           {({ getFieldValue }) =>
             getFieldValue('model_type') === 'chat' && (
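For reference, what the new `OpenAI_APIChat` wrapper ultimately enables: pointing the official `openai` client at any compatible endpoint. A minimal end-to-end sketch, assuming the `openai>=1.x` client; the base URL and model id below are placeholders for whatever server is actually deployed (vLLM, llama.cpp server, LM Studio, etc.):

```python
from openai import OpenAI

# Many self-hosted OpenAI-compatible servers accept any non-empty key.
client = OpenAI(
    api_key="sk-anything",                # placeholder key
    base_url="http://localhost:8000/v1",  # placeholder endpoint
)

resp = client.chat.completions.create(
    model="my-model",                     # placeholder model id
    messages=[{"role": "user", "content": "Say hello."}],
)
print(resp.choices[0].message.content)
```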