From 9ae57eb37089336b29b71635c7deaa9525537a3e Mon Sep 17 00:00:00 2001
From: 黄腾 <101850389+hangters@users.noreply.github.com>
Date: Wed, 17 Jul 2024 15:32:51 +0800
Subject: [PATCH] fix MiniMax api error (#1567)

### What problem does this PR solve?

#1353. The `Minimax` factory and its model names did not match what the MiniMax `chatcompletion_v2` API expects, so chat requests failed. This PR renames the factory and its models, reimplements `MiniMaxChat` to call the endpoint directly over HTTP, and widens the `api_key` column to 1024 characters so longer keys fit.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Zhedong Cen
---
 api/db/db_models.py     |  8 +++-
 conf/llm_factories.json | 12 +++---
 rag/llm/chat_model.py   | 84 ++++++++++++++++++++++++++++++++++++++---
 3 files changed, 91 insertions(+), 13 deletions(-)

diff --git a/api/db/db_models.py b/api/db/db_models.py
index 4491526bf..ff2fab916 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -558,7 +558,7 @@ class TenantLLM(DataBaseModel):
         null=True,
         help_text="LLM name",
         default="")
-    api_key = CharField(max_length=255, null=True, help_text="API KEY")
+    api_key = CharField(max_length=1024, null=True, help_text="API KEY")
     api_base = CharField(max_length=255, null=True, help_text="API Base")
     used_tokens = IntegerField(default=0)

@@ -885,3 +885,9 @@ def migrate_db():
         )
     except Exception as e:
         pass
+    try:
+        migrate(
+            migrator.alter_column_type('tenant_llm', 'api_key', CharField(max_length=1024, null=True, help_text="API KEY"))
+        )
+    except Exception as e:
+        pass
\ No newline at end of file
diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index 446b1eeb8..0a438cba5 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -433,37 +433,37 @@
         ]
     },
     {
-        "name": "Minimax",
+        "name": "MiniMax",
         "logo": "",
         "tags": "LLM,TEXT EMBEDDING",
         "status": "1",
         "llm": [
             {
-                "llm_name": "abab6.5",
+                "llm_name": "abab6.5-chat",
                 "tags": "LLM,CHAT,8k",
                 "max_tokens": 8192,
                 "model_type": "chat"
             },
             {
-                "llm_name": "abab6.5s",
+                "llm_name": "abab6.5s-chat",
                 "tags": "LLM,CHAT,245k",
                 "max_tokens": 245760,
                 "model_type": "chat"
             },
             {
-                "llm_name": "abab6.5t",
+                "llm_name": "abab6.5t-chat",
                 "tags": "LLM,CHAT,8k",
                 "max_tokens": 8192,
                 "model_type": "chat"
             },
             {
-                "llm_name": "abab6.5g",
+                "llm_name": "abab6.5g-chat",
                 "tags": "LLM,CHAT,8k",
                 "max_tokens": 8192,
                 "model_type": "chat"
             },
             {
-                "llm_name": "abab5.5s",
+                "llm_name": "abab5.5s-chat",
                 "tags": "LLM,CHAT,8k",
                 "max_tokens": 8192,
                 "model_type": "chat"
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index e883869a0..cf00a7fc5 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -24,7 +24,8 @@ from volcengine.maas.v2 import MaasService
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 from groq import Groq
-
+import json
+import requests

 class Base(ABC):
     def __init__(self, key, model_name, base_url):
@@ -475,11 +476,83 @@ class VolcEngineChat(Base):


 class MiniMaxChat(Base):
-    def __init__(self, key, model_name="abab6.5s-chat",
-                 base_url="https://api.minimax.chat/v1/text/chatcompletion_v2"):
+    def __init__(
+        self,
+        key,
+        model_name,
+        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
+    ):
         if not base_url:
-            base_url="https://api.minimax.chat/v1/text/chatcompletion_v2"
-        super().__init__(key, model_name, base_url)
+            base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
+        self.base_url = base_url
+        self.model_name = model_name
+        self.api_key = key
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
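+        # Build an OpenAI-style JSON body from the filtered generation
+        # options and POST it to the chatcompletion_v2 endpoint directly,
+        # instead of going through the client that super().__init__ set up.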
"application/json", + } + payload = json.dumps( + {"model": self.model_name, "messages": history, **gen_conf} + ) + try: + response = requests.request( + "POST", url=self.base_url, headers=headers, data=payload + ) + print(response, flush=True) + response = response.json() + ans = response["choices"][0]["message"]["content"].strip() + if response["choices"][0]["finish_reason"] == "length": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english( + [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + return ans, response["usage"]["total_tokens"] + except Exception as e: + return "**ERROR**: " + str(e), 0 + + def chat_streamly(self, system, history, gen_conf): + if system: + history.insert(0, {"role": "system", "content": system}) + ans = "" + total_tokens = 0 + try: + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + payload = json.dumps( + { + "model": self.model_name, + "messages": history, + "stream": True, + **gen_conf, + } + ) + response = requests.request( + "POST", + url=self.base_url, + headers=headers, + data=payload, + ) + for resp in response.text.split("\n\n")[:-1]: + resp = json.loads(resp[6:]) + if "delta" in resp["choices"][0]: + text = resp["choices"][0]["delta"]["content"] + else: + continue + ans += text + total_tokens += num_tokens_from_string(text) + yield ans + + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield total_tokens class MistralChat(Base): @@ -748,4 +821,3 @@ class OpenRouterChat(Base): self.base_url = "https://openrouter.ai/api/v1" self.client = OpenAI(base_url=self.base_url, api_key=key) self.model_name = model_name -