From 9ffd7ae32147481394716b2038e733ab4eb9c17b Mon Sep 17 00:00:00 2001
From: yungongzi
Date: Tue, 28 May 2024 09:09:37 +0800
Subject: [PATCH] Added support for Baichuan LLM (#934)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What problem does this PR solve?

- Added support for the Baichuan LLM family (chat and text-embedding models)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: 海贼宅
---
 api/db/init_data.py                       | 50 +++++++++++-
 rag/llm/__init__.py                       |  6 +-
 rag/llm/chat_model.py                     | 78 +++++++++++++++++++
 rag/llm/embedding_model.py                |  9 +++
 web/src/assets/svg/llm/baichuan.svg       | 28 +++++++
 .../user-setting/setting-model/index.tsx  |  1 +
 6 files changed, 169 insertions(+), 3 deletions(-)
 create mode 100644 web/src/assets/svg/llm/baichuan.svg

diff --git a/api/db/init_data.py b/api/db/init_data.py
index 3044ea976..1a4706f25 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -137,7 +137,12 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM, TEXT EMBEDDING",
     "status": "1",
-}
+},{
+    "name": "BaiChuan",
+    "logo": "",
+    "tags": "LLM, TEXT EMBEDDING",
+    "status": "1",
+},
 # {
 #     "name": "文心一言",
 #     "logo": "",
@@ -392,6 +397,49 @@ def init_llm_factory():
             "max_tokens": 4096,
             "model_type": LLMType.CHAT.value
         },
+        # ------------------------ BaiChuan -----------------------
+        {
+            "fid": factory_infos[10]["name"],
+            "llm_name": "Baichuan2-Turbo",
+            "tags": "LLM,CHAT,32K",
+            "max_tokens": 32768,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[10]["name"],
+            "llm_name": "Baichuan2-Turbo-192k",
+            "tags": "LLM,CHAT,192K",
+            "max_tokens": 196608,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[10]["name"],
+            "llm_name": "Baichuan3-Turbo",
+            "tags": "LLM,CHAT,32K",
+            "max_tokens": 32768,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[10]["name"],
+            "llm_name": "Baichuan3-Turbo-128k",
+            "tags": "LLM,CHAT,128K",
+            "max_tokens": 131072,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[10]["name"],
+            "llm_name": "Baichuan4",
+            "tags": "LLM,CHAT,128K",
+            "max_tokens": 131072,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[10]["name"],
+            "llm_name": "Baichuan-Text-Embedding",
+            "tags": "TEXT EMBEDDING",
+            "max_tokens": 512,
+            "model_type": LLMType.EMBEDDING.value
+        },
     ]
     for info in factory_infos:
         try:
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 546a09946..9fc114f06 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -26,7 +26,8 @@ EmbeddingModel = {
     "ZHIPU-AI": ZhipuEmbed,
     "FastEmbed": FastEmbed,
     "Youdao": YoudaoEmbed,
-    "DeepSeek": DefaultEmbedding
+    "DeepSeek": DefaultEmbedding,
+    "BaiChuan": BaiChuanEmbed
 }
 
 
@@ -47,6 +48,7 @@ ChatModel = {
     "Ollama": OllamaChat,
     "Xinference": XinferenceChat,
     "Moonshot": MoonshotChat,
-    "DeepSeek": DeepSeekChat
+    "DeepSeek": DeepSeekChat,
+    "BaiChuan": BaiChuanChat
 }
 
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 484b13929..a9530fe61 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -95,6 +95,84 @@ class DeepSeekChat(Base):
         super().__init__(key, model_name, base_url)
 
 
+class BaiChuanChat(Base):
+    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
+        if not base_url:
+            base_url = "https://api.baichuan-ai.com/v1"
+        super().__init__(key, model_name, base_url)
+
+    @staticmethod
+    def _format_params(params):
+        return {
+            "temperature": params.get("temperature", 0.3),
+            "max_tokens": params.get("max_tokens", 2048),
+            "top_p": params.get("top_p", 0.85),
params.get("top_p", 0.85), + } + + def chat(self, system, history, gen_conf): + if system: + history.insert(0, {"role": "system", "content": system}) + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=history, + extra_body={ + "tools": [{ + "type": "web_search", + "web_search": { + "enable": True, + "search_mode": "performance_first" + } + }] + }, + **self._format_params(gen_conf)) + ans = response.choices[0].message.content.strip() + if response.choices[0].finish_reason == "length": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english( + [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + return ans, response.usage.total_tokens + except openai.APIError as e: + return "**ERROR**: " + str(e), 0 + + def chat_streamly(self, system, history, gen_conf): + if system: + history.insert(0, {"role": "system", "content": system}) + ans = "" + total_tokens = 0 + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=history, + extra_body={ + "tools": [{ + "type": "web_search", + "web_search": { + "enable": True, + "search_mode": "performance_first" + } + }] + }, + stream=True, + **self._format_params(gen_conf)) + for resp in response: + if resp.choices[0].finish_reason == "stop": + if not resp.choices[0].delta.content: + continue + total_tokens = resp.usage.get('total_tokens', 0) + if not resp.choices[0].delta.content: + continue + ans += resp.choices[0].delta.content + if resp.choices[0].finish_reason == "length": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english( + [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + yield ans + + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield total_tokens + + class QWenChat(Base): def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs): import dashscope diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 6ef741133..43485d99c 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -104,6 +104,15 @@ class OpenAIEmbed(Base): return np.array(res.data[0].embedding), res.usage.total_tokens +class BaiChuanEmbed(OpenAIEmbed): + def __init__(self, key, + model_name='Baichuan-Text-Embedding', + base_url='https://api.baichuan-ai.com/v1'): + if not base_url: + base_url = "https://api.baichuan-ai.com/v1" + super().__init__(key, model_name, base_url) + + class QWenEmbed(Base): def __init__(self, key, model_name="text_embedding_v2", **kwargs): dashscope.api_key = key diff --git a/web/src/assets/svg/llm/baichuan.svg b/web/src/assets/svg/llm/baichuan.svg new file mode 100644 index 000000000..6ea61b5b8 --- /dev/null +++ b/web/src/assets/svg/llm/baichuan.svg @@ -0,0 +1,28 @@ + + + + diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx index 071d9ec2a..bffbc0386 100644 --- a/web/src/pages/user-setting/setting-model/index.tsx +++ b/web/src/pages/user-setting/setting-model/index.tsx @@ -55,6 +55,7 @@ const IconMap = { Xinference: 'xinference', DeepSeek: 'deepseek', VolcEngine: 'volc_engine', + BaiChuan: 'baichuan', }; const LlmIcon = ({ name }: { name: string }) => {