diff --git a/api/db/init_data.py b/api/db/init_data.py
index 3017160cc..0854602a0 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -157,6 +157,11 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING",
     "status": "1",
+},{
+    "name": "Mistral",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING",
+    "status": "1",
 }
 # {
 #     "name": "文心一言",
@@ -584,6 +589,63 @@ def init_llm_factory():
             "max_tokens": 8192,
             "model_type": LLMType.CHAT.value
         },
+        # ------------------------ Mistral -----------------------
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "open-mixtral-8x22b",
+            "tags": "LLM,CHAT,64k",
+            "max_tokens": 64000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "open-mixtral-8x7b",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "open-mistral-7b",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-large-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-small-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-medium-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "codestral-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-embed",
+            "tags": "TEXT EMBEDDING,8k",
+            "max_tokens": 8192,
+            "model_type": LLMType.EMBEDDING.value
+        },
     ]
     for info in factory_infos:
         try:
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index c3d58f871..d70136a57 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -29,7 +29,8 @@ EmbeddingModel = {
     "Youdao": YoudaoEmbed,
     "BaiChuan": BaiChuanEmbed,
     "Jina": JinaEmbed,
-    "BAAI": DefaultEmbedding
+    "BAAI": DefaultEmbedding,
+    "Mistral": MistralEmbed
 }
 
 
@@ -52,7 +53,8 @@ ChatModel = {
     "Moonshot": MoonshotChat,
     "DeepSeek": DeepSeekChat,
     "BaiChuan": BaiChuanChat,
-    "MiniMax": MiniMaxChat
+    "MiniMax": MiniMaxChat,
+    "Mistral": MistralChat
 }
 
 
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 2dc63fd99..f6c0666fe 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -472,3 +472,57 @@ class MiniMaxChat(Base):
         if not base_url:
             base_url="https://api.minimax.chat/v1/text/chatcompletion_v2"
         super().__init__(key, model_name, base_url)
+
+
+class MistralChat(Base):
+
+    def __init__(self, key, model_name, base_url=None):
+        from mistralai.client import MistralClient
+        self.client = MistralClient(api_key=key)
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            ans = response.choices[0].message.content
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.total_tokens
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat_stream(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            for resp in response:
+                if not resp.choices or not resp.choices[0].delta.content:
+                    continue
+                ans += resp.choices[0].delta.content
+                total_tokens += 1
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index b0e1bff96..6b45b6035 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -343,4 +343,24 @@ class InfinityEmbed(Base):
     def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
         # Using the internal tokenizer to encode the texts and get the total
         # number of tokens
-        return self.encode([text])
\ No newline at end of file
+        return self.encode([text])
+
+
+class MistralEmbed(Base):
+    def __init__(self, key, model_name="mistral-embed",
+                 base_url=None):
+        from mistralai.client import MistralClient
+        self.client = MistralClient(api_key=key)
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        texts = [truncate(t, 8192) for t in texts]
+        res = self.client.embeddings(input=texts,
+                                     model=self.model_name)
+        return np.array([d.embedding for d in res.data]
+                        ), res.usage.total_tokens
+
+    def encode_queries(self, text):
+        res = self.client.embeddings(input=[truncate(text, 8192)],
+                                     model=self.model_name)
+        return np.array(res.data[0].embedding), res.usage.total_tokens
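
For reviewers who want to exercise the new chat binding locally, here is a minimal smoke test; it is not part of the diff. The `MISTRAL_API_KEY` environment variable is an assumption, the `mistralai` 0.x package must be installed, and the model name is one of the entries registered in `init_data.py` above.

```python
# Hedged smoke test for MistralChat (not part of this PR).
# Assumptions: `pip install mistralai` (the 0.x client imported above) and a
# valid key in the MISTRAL_API_KEY environment variable.
import os

from rag.llm import ChatModel  # registry extended by this PR

chat_mdl = ChatModel["Mistral"](
    key=os.environ["MISTRAL_API_KEY"],
    model_name="open-mistral-7b")

# chat() returns (answer, total_tokens); gen_conf keys other than
# temperature/top_p/max_tokens are stripped before the API call.
ans, tks = chat_mdl.chat(
    system="You are a concise assistant.",
    history=[{"role": "user", "content": "Reply with one word."}],
    gen_conf={"temperature": 0.3, "max_tokens": 32, "frequency_penalty": 0.5})
print(ans, tks)
```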
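The embedding side can be checked the same way. Same assumptions as above; the 1024-dimension shape in the comment reflects mistral-embed's published output size and is not asserted anywhere in the diff.

```python
# Hedged sketch for MistralEmbed (not part of this PR). encode() batches a
# list of texts, encode_queries() takes a single string; both truncate input
# to the 8192-token window registered for mistral-embed in init_data.py.
import os

from rag.llm import EmbeddingModel

embd_mdl = EmbeddingModel["Mistral"](key=os.environ["MISTRAL_API_KEY"])

vecs, tks = embd_mdl.encode(["hello world", "bonjour"])
print(vecs.shape, tks)   # expected (2, 1024) for mistral-embed

qvec, tks = embd_mdl.encode_queries("hello world")
print(qvec.shape, tks)
```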