add support for Mistral (#1153)

### What problem does this PR solve?

#433 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
KevinHuSh 2024-06-14 11:32:58 +08:00 committed by GitHub
parent a25d32496c
commit 7dc39cbfa6
4 changed files with 141 additions and 3 deletions


@@ -157,6 +157,11 @@ factory_infos = [{
    "logo": "",
    "tags": "LLM,TEXT EMBEDDING",
    "status": "1",
},{
    "name": "Mistral",
    "logo": "",
    "tags": "LLM,TEXT EMBEDDING",
    "status": "1",
}
# {
#     "name": "文心一言",
@@ -584,6 +589,63 @@ def init_llm_factory():
            "max_tokens": 8192,
            "model_type": LLMType.CHAT.value
        },
        # ------------------------ Mistral -----------------------
        {
            "fid": factory_infos[14]["name"],
            "llm_name": "open-mixtral-8x22b",
            "tags": "LLM,CHAT,64k",
            "max_tokens": 64000,
            "model_type": LLMType.CHAT.value
        },
        {
            "fid": factory_infos[14]["name"],
            "llm_name": "open-mixtral-8x7b",
            "tags": "LLM,CHAT,32k",
            "max_tokens": 32000,
            "model_type": LLMType.CHAT.value
        },
        {
            "fid": factory_infos[14]["name"],
            "llm_name": "open-mistral-7b",
            "tags": "LLM,CHAT,32k",
            "max_tokens": 32000,
            "model_type": LLMType.CHAT.value
        },
        {
            "fid": factory_infos[14]["name"],
            "llm_name": "mistral-large-latest",
            "tags": "LLM,CHAT,32k",
            "max_tokens": 32000,
            "model_type": LLMType.CHAT.value
        },
        {
            "fid": factory_infos[14]["name"],
            "llm_name": "mistral-small-latest",
            "tags": "LLM,CHAT,32k",
            "max_tokens": 32000,
            "model_type": LLMType.CHAT.value
        },
        {
            "fid": factory_infos[14]["name"],
            "llm_name": "mistral-medium-latest",
            "tags": "LLM,CHAT,32k",
            "max_tokens": 32000,
            "model_type": LLMType.CHAT.value
        },
        {
            "fid": factory_infos[14]["name"],
            "llm_name": "codestral-latest",
            "tags": "LLM,CHAT,32k",
            "max_tokens": 32000,
            "model_type": LLMType.CHAT.value
        },
        {
            "fid": factory_infos[14]["name"],
            "llm_name": "mistral-embed",
            "tags": "TEXT EMBEDDING,8k",
            "max_tokens": 8192,
            "model_type": LLMType.EMBEDDING.value
        },
    ]
    for info in factory_infos:
        try:
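Every `fid` above points at `factory_infos[14]`, so that entry must be the Mistral factory registered in the first hunk. An illustrative sanity check, not part of the commit:

```python
# Illustrative only, not part of the commit: every Mistral model row
# references factory_infos[14] through "fid", so that entry must be Mistral.
assert factory_infos[14]["name"] == "Mistral"
assert factory_infos[14]["tags"] == "LLM,TEXT EMBEDDING"
```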


@@ -29,7 +29,8 @@ EmbeddingModel = {
    "Youdao": YoudaoEmbed,
    "BaiChuan": BaiChuanEmbed,
    "Jina": JinaEmbed,
    "BAAI": DefaultEmbedding,
    "Mistral": MistralEmbed
}
@@ -52,7 +53,8 @@ ChatModel = {
    "Moonshot": MoonshotChat,
    "DeepSeek": DeepSeekChat,
    "BaiChuan": BaiChuanChat,
    "MiniMax": MiniMaxChat,
    "Mistral": MistralChat
}
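Both registries map a provider name to its implementation class, so resolving a configured provider is a plain dict lookup. A minimal sketch of that lookup (the API key is a placeholder):

```python
# Minimal sketch of the registry lookup; the API key is a placeholder.
chat_mdl = ChatModel["Mistral"](key="your-mistral-api-key",
                                model_name="open-mistral-7b")
embd_mdl = EmbeddingModel["Mistral"](key="your-mistral-api-key",
                                     model_name="mistral-embed")
```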


@@ -472,3 +472,57 @@ class MiniMaxChat(Base):
        if not base_url:
            base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
        super().__init__(key, model_name, base_url)


class MistralChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from mistralai.client import MistralClient
        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        # Mistral's chat endpoint only accepts these sampling parameters.
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                # The Chinese branch reads: "The answer was truncated for length; continue?"
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:  # the Mistral client raises its own error types, not openai.APIError
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(
                model=self.model_name,
                messages=history,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1  # approximation: counts streamed chunks, not true tokens
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:  # see note above on exception types
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens
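For reference, a short sketch of how `MistralChat` could be exercised on its own, assuming a valid key (the key, prompts, and `gen_conf` values are placeholders):

```python
# Placeholder key and prompts; gen_conf keys outside temperature/top_p/max_tokens
# (e.g. presence_penalty) are silently dropped by chat()/chat_streamly().
mdl = MistralChat(key="your-mistral-api-key", model_name="open-mistral-7b")

ans, used_tokens = mdl.chat(
    system="You are a concise assistant.",
    history=[{"role": "user", "content": "Name one French city."}],
    gen_conf={"temperature": 0.3, "max_tokens": 64, "presence_penalty": 0.1})
print(ans, used_tokens)

# chat_streamly() yields the accumulated answer after every chunk and,
# once the stream ends, the chunk count as its final value.
for out in mdl.chat_streamly("You are terse.",
                             [{"role": "user", "content": "Hi"}],
                             {"temperature": 0.3}):
    print(out)
```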


@@ -343,4 +343,24 @@ class InfinityEmbed(Base):
    def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
        # Using the internal tokenizer to encode the texts and get the total
        # number of tokens
        return self.encode([text])


class MistralEmbed(Base):
    def __init__(self, key, model_name="mistral-embed",
                 base_url=None):
        from mistralai.client import MistralClient
        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        # mistral-embed accepts at most 8192 tokens per input, so clip each text first.
        texts = [truncate(t, 8192) for t in texts]
        res = self.client.embeddings(input=texts,
                                     model=self.model_name)
        return np.array([d.embedding for d in res.data]), res.usage.total_tokens

    def encode_queries(self, text):
        res = self.client.embeddings(input=[truncate(text, 8192)],
                                     model=self.model_name)
        return np.array(res.data[0].embedding), res.usage.total_tokens
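And the embedding side, with the same caveats (placeholder key; `mistral-embed` returns 1024-dimensional vectors per Mistral's documentation):

```python
# Placeholder key; mistral-embed produces 1024-dimensional vectors.
embd = MistralEmbed(key="your-mistral-api-key")

vecs, used_tokens = embd.encode(["first chunk", "second chunk"])
print(vecs.shape, used_tokens)   # (2, 1024), plus tokens consumed

qvec, used_tokens = embd.encode_queries("What does the corpus say about X?")
print(qvec.shape, used_tokens)   # (1024,)
```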