From 94cb66ba806138d0bf13c2dc687121553ee98efd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E8=85=BE?= <101850389+hangters@users.noreply.github.com> Date: Mon, 12 Aug 2024 10:15:21 +0800 Subject: [PATCH] add support for TogetherAI (#1890) ### What problem does this PR solve? #1853 add support for TogetherAI ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Zhedong Cen Co-authored-by: Kevin Hu --- api/apps/llm_app.py | 2 +- conf/llm_factories.json | 9 ++++++++- rag/llm/__init__.py | 8 ++++++-- rag/llm/chat_model.py | 7 +++++++ rag/llm/cv_model.py | 7 +++++++ rag/llm/embedding_model.py | 8 ++++++++ rag/llm/rerank_model.py | 8 ++++++++ web/src/assets/svg/llm/together-ai.svg | 15 +++++++++++++++ web/src/pages/user-setting/constants.tsx | 2 +- .../pages/user-setting/setting-model/constant.ts | 1 + 10 files changed, 62 insertions(+), 5 deletions(-) create mode 100644 web/src/assets/svg/llm/together-ai.svg diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 1d54550ff..26609f635 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -135,7 +135,7 @@ def add_llm(): api_key = req.get("api_key","xxxxxxxxxxxxxxx") else: llm_name = req["llm_name"] - api_key = "xxxxxxxxxxxxxxx" + api_key = req.get("api_key","xxxxxxxxxxxxxxx") llm = { "tenant_id": current_user.id, diff --git a/conf/llm_factories.json b/conf/llm_factories.json index c4566bbda..cb8020c14 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -2443,6 +2443,13 @@ } ] }, + { + "name": "TogetherAI", + "logo": "", + "tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT", + "status": "1", + "llm": [] + }, { "name": "PerfXCloud", "logo": "", @@ -2594,6 +2601,6 @@ "model_type": "embedding" } ] - } + } ] } diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index dcc9f5523..a6e2aa417 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -39,6 +39,7 @@ EmbeddingModel = { "LM-Studio": LmStudioEmbed, "OpenAI-API-Compatible": OpenAI_APIEmbed, "cohere": CoHereEmbed, + "TogetherAI": TogetherAIEmbed, "PerfXCloud": PerfXCloudEmbed, } @@ -57,7 +58,8 @@ CvModel = { "NVIDIA": NvidiaCV, "LM-Studio": LmStudioCV, "StepFun":StepFunCV, - "OpenAI-API-Compatible": OpenAI_APICV + "OpenAI-API-Compatible": OpenAI_APICV, + "TogetherAI": TogetherAICV } @@ -86,6 +88,7 @@ ChatModel = { "OpenAI-API-Compatible": OpenAI_APIChat, "cohere": CoHereChat, "LeptonAI": LeptonAIChat, + "TogetherAI": TogetherAIChat, "PerfXCloud": PerfXCloudChat } @@ -98,7 +101,8 @@ RerankModel = { "NVIDIA": NvidiaRerank, "LM-Studio": LmStudioRerank, "OpenAI-API-Compatible": OpenAI_APIRerank, - "cohere": CoHereRerank + "cohere": CoHereRerank, + "TogetherAI": TogetherAIRerank } diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index dc908c40d..256e0f0ba 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -990,6 +990,13 @@ class LeptonAIChat(Base): super().__init__(key, model_name, base_url) +class TogetherAIChat(Base): + def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"): + if not base_url: + base_url = "https://api.together.xyz/v1" + super().__init__(key, model_name, base_url) + + class PerfXCloudChat(Base): def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"): if not base_url: diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 82a773024..ee3de6bbd 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -649,3 +649,10 @@ class OpenAI_APICV(GptV4): self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name.split("___")[0] self.lang = lang + + +class TogetherAICV(GptV4): + def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"): + if not base_url: + base_url = "https://api.together.xyz/v1" + super().__init__(key, model_name, base_url) \ No newline at end of file diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 4ceef04f6..c663a762e 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -555,8 +555,16 @@ class CoHereEmbed(Base): ) +class TogetherAIEmbed(OllamaEmbed): + def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"): + if not base_url: + base_url = "https://api.together.xyz/v1" + super().__init__(key, model_name, base_url) + + class PerfXCloudEmbed(OpenAIEmbed): def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"): if not base_url: base_url = "https://cloud.perfxlab.cn/v1" super().__init__(key, model_name, base_url) + \ No newline at end of file diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 2f142ef0e..7f39f7a6a 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -245,3 +245,11 @@ class CoHereRerank(Base): rank = np.array([d.relevance_score for d in res.results]) indexs = [d.index for d in res.results] return rank[indexs], token_count + + +class TogetherAIRerank(Base): + def __init__(self, key, model_name, base_url): + pass + + def similarity(self, query: str, texts: list): + raise NotImplementedError("The api has not been implement") \ No newline at end of file diff --git a/web/src/assets/svg/llm/together-ai.svg b/web/src/assets/svg/llm/together-ai.svg new file mode 100644 index 000000000..5f39ecce1 --- /dev/null +++ b/web/src/assets/svg/llm/together-ai.svg @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/web/src/pages/user-setting/constants.tsx b/web/src/pages/user-setting/constants.tsx index cf2bce9ed..98b6e42be 100644 --- a/web/src/pages/user-setting/constants.tsx +++ b/web/src/pages/user-setting/constants.tsx @@ -17,4 +17,4 @@ export const UserSettingIconMap = { export * from '@/constants/setting'; -export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible"]; +export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible",'TogetherAI']; diff --git a/web/src/pages/user-setting/setting-model/constant.ts b/web/src/pages/user-setting/setting-model/constant.ts index db5b03077..2fa10d554 100644 --- a/web/src/pages/user-setting/setting-model/constant.ts +++ b/web/src/pages/user-setting/setting-model/constant.ts @@ -25,6 +25,7 @@ export const IconMap = { 'OpenAI-API-Compatible': 'openai-api', cohere: 'cohere', Lepton: 'lepton', + TogetherAI:'together-ai', PerfXCould: 'perfx-could' };