Added support for Baichuan LLM (#934)

### What problem does this PR solve?

- Added support for Baichuan LLM

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: 海贼宅 <stu_xyx@163.com>
This commit is contained in:
yungongzi 2024-05-28 09:09:37 +08:00 committed by GitHub
parent ec6ae744a1
commit 9ffd7ae321
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 169 additions and 3 deletions

View File

@ -137,7 +137,12 @@ factory_infos = [{
"logo": "",
"tags": "LLM, TEXT EMBEDDING",
"status": "1",
}
},{
"name": "BaiChuan",
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",
},
# {
# "name": "文心一言",
# "logo": "",
@ -392,6 +397,49 @@ def init_llm_factory():
"max_tokens": 4096,
"model_type": LLMType.CHAT.value
},
# ------------------------ BaiChuan -----------------------
{
"fid": factory_infos[10]["name"],
"llm_name": "Baichuan2-Turbo",
"tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": LLMType.CHAT.value
},
{
"fid": factory_infos[10]["name"],
"llm_name": "Baichuan2-Turbo-192k",
"tags": "LLM,CHAT,192K",
"max_tokens": 196608,
"model_type": LLMType.CHAT.value
},
{
"fid": factory_infos[10]["name"],
"llm_name": "Baichuan3-Turbo",
"tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": LLMType.CHAT.value
},
{
"fid": factory_infos[10]["name"],
"llm_name": "Baichuan3-Turbo-128k",
"tags": "LLM,CHAT,128K",
"max_tokens": 131072,
"model_type": LLMType.CHAT.value
},
{
"fid": factory_infos[10]["name"],
"llm_name": "Baichuan4",
"tags": "LLM,CHAT,128K",
"max_tokens": 131072,
"model_type": LLMType.CHAT.value
},
{
"fid": factory_infos[10]["name"],
"llm_name": "Baichuan-Text-Embedding",
"tags": "TEXT EMBEDDING",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
},
]
for info in factory_infos:
try:

View File

@ -26,7 +26,8 @@ EmbeddingModel = {
"ZHIPU-AI": ZhipuEmbed,
"FastEmbed": FastEmbed,
"Youdao": YoudaoEmbed,
"DeepSeek": DefaultEmbedding
"DeepSeek": DefaultEmbedding,
"BaiChuan": BaiChuanEmbed
}
@ -47,6 +48,7 @@ ChatModel = {
"Ollama": OllamaChat,
"Xinference": XinferenceChat,
"Moonshot": MoonshotChat,
"DeepSeek": DeepSeekChat
"DeepSeek": DeepSeekChat,
"BaiChuan": BaiChuanChat
}

View File

@ -95,6 +95,84 @@ class DeepSeekChat(Base):
super().__init__(key, model_name, base_url)
class BaiChuanChat(Base):
def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
if not base_url:
base_url = "https://api.baichuan-ai.com/v1"
super().__init__(key, model_name, base_url)
@staticmethod
def _format_params(params):
return {
"temperature": params.get("temperature", 0.3),
"max_tokens": params.get("max_tokens", 2048),
"top_p": params.get("top_p", 0.85),
}
def chat(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
extra_body={
"tools": [{
"type": "web_search",
"web_search": {
"enable": True,
"search_mode": "performance_first"
}
}]
},
**self._format_params(gen_conf))
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return ans, response.usage.total_tokens
except openai.APIError as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
ans = ""
total_tokens = 0
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
extra_body={
"tools": [{
"type": "web_search",
"web_search": {
"enable": True,
"search_mode": "performance_first"
}
}]
},
stream=True,
**self._format_params(gen_conf))
for resp in response:
if resp.choices[0].finish_reason == "stop":
if not resp.choices[0].delta.content:
continue
total_tokens = resp.usage.get('total_tokens', 0)
if not resp.choices[0].delta.content:
continue
ans += resp.choices[0].delta.content
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
class QWenChat(Base):
def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
import dashscope

View File

@ -104,6 +104,15 @@ class OpenAIEmbed(Base):
return np.array(res.data[0].embedding), res.usage.total_tokens
class BaiChuanEmbed(OpenAIEmbed):
def __init__(self, key,
model_name='Baichuan-Text-Embedding',
base_url='https://api.baichuan-ai.com/v1'):
if not base_url:
base_url = "https://api.baichuan-ai.com/v1"
super().__init__(key, model_name, base_url)
class QWenEmbed(Base):
def __init__(self, key, model_name="text_embedding_v2", **kwargs):
dashscope.api_key = key

View File

@ -0,0 +1,28 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" width="256px" height="256px" viewBox="0 0 256 256" enable-background="new 0 0 256 256" xml:space="preserve"> <image id="image0" width="256" height="256" x="0" y="0"
xlink:href="
AAB1MAAA6mAAADqYAAAXb5JfxUYAAACuUExURf////349fzy7/77+f349fzy7//+/fXOv/bTxf//
/+2njfjg1u+xmfG6pumXd/jg1uaDXv349fLArfPIt++xmfvv6u6tlP77+frp4uqbfeaIZPPEsvC2
n/LArf////LCr/////G6pvLCr++xmfC2n/zy7/fXyu2njeiPbf318fnl3eudfvG+quykif/////+
/eBnOeN2TON3TuJ0SuFvROFqPeN5UOV+V+R8VP///3/zEGYAAAAwdFJOUwA+RTlbYVqZk1i/Y7yt
4YH5Yquhv27DXnbb86W1plCpJZiOq7lojcvrZnrZrc8yNOA7KFcAAAABYktHRACIBR1IAAAACXBI
WXMAAAsTAAALEwEAmpwYAAAAB3RJTUUH6AUbBTg4QG2qLwAAAyhJREFUeNrt3dlS2zAUgGHaUjvq
ljht4pB0oy3ZoDTgLH3/JytThinjWNJxhwmco/+/RkL6RngSc6GjIyIiIiIiIiIiIiIiIiIiIiIi
IiIiIiIiIvq/nj0X9qI28PilsOPH2FaWe+q42k++Wgl7XRv4Rjrw7WMAdHyruUoDwL1LHKDrXU0a
AFkvbQCX+1eTBEBxnThAP7CaFABcL20A9z60mgQAiiptAPchuBr7AIN12gBuGF6NeYCyCq/GOkDk
CWAfYLRJG8CdxFZjHKAbXY1tgPgBMA4wiD0BjAP4X4QlAtARrMYyQPBrcAoAuWQ1hgHcOHGA+GcA
2wBlL22A8IuwBACKSrYaswDCA2AWIPDPsCQA3ES6mocD+FgJ+3QAgHJ7eIDP4g4AIH0CPCTAk+oL
AAAAAAAAAAAAwAEBvoozCnC6FfbNKMB36cBTAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAgP3cDzHAziTA2Vq6mFXfIkCLA3Bd1gdbAMjEtwusJvUDYAFAesHMTbtsb7QB
gLMrMcB07wAYAHAz8f7n+/s3ACC4ZfGuvGG4eoDYTcv3GmcN49UDLOSfAZZN47UDRG9a/te5a5pA
O8BS/ARYDxonUA7g5uIDMGw8ANoBZLcs/m3UPINuAHcu3v+s+QAoBxiJ918VnilUAwiu2r7rwnMA
VAO4qXj/u9I3iWaA7LcYYOI7AKoBJFdt3zYvvJMoBmjxGWDqn0UxQIsD4Pyz6AUQ3bV+WzcwjVoA
1xfv/2fgAOgFkH8NXhehebQCtDgAs9ABUAtQ7qS/fzMKTqQUwF2ID8AweAC0AkjvWr95AizCM+kE
aPEqfBI+AEoB5F+De5H96wRo8Sa0H5tLJcBC/ATYFrG5NAK0eALksb8AnQDiV+Hj6P51Akh/9WYZ
n8w0wEn8AOgEWApbCCa7/CXssjbwSd07TEREREREREREREREREREREREREREevoDtKBqvEP0IEYA
AAAldEVYdGRhdGU6Y3JlYXRlADIwMjQtMDUtMjdUMDU6NTY6NTYrMDA6MDANJESyAAAAJXRFWHRk
YXRlOm1vZGlmeQAyMDI0LTA1LTI3VDA1OjU2OjU2KzAwOjAwfHn8DgAAACh0RVh0ZGF0ZTp0aW1l
c3RhbXAAMjAyNC0wNS0yN1QwNTo1Njo1NiswMDowMCts3dEAAAAASUVORK5CYII=" />
</svg>

After

Width:  |  Height:  |  Size: 2.3 KiB

View File

@ -55,6 +55,7 @@ const IconMap = {
Xinference: 'xinference',
DeepSeek: 'deepseek',
VolcEngine: 'volc_engine',
BaiChuan: 'baichuan',
};
const LlmIcon = ({ name }: { name: string }) => {