diff --git a/api/db/init_data.py b/api/db/init_data.py
index ad3c16d98..ec295d57a 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -180,6 +180,12 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
     "status": "1",
+},
+{
+    "name": "Groq",
+    "logo": "",
+    "tags": "LLM",
+    "status": "1",
 }
 # {
 #     "name": "文心一言",
@@ -933,6 +939,42 @@ def init_llm_factory():
         "tags": "TEXT EMBEDDING",
         "max_tokens": 2048,
         "model_type": LLMType.EMBEDDING.value
-    }
+    },
+    # ------------------------ Groq -----------------------
+    {
+        "fid": factory_infos[18]["name"],
+        "llm_name": "gemma-7b-it",
+        "tags": "LLM,CHAT,15k",
+        "max_tokens": 8192,
+        "model_type": LLMType.CHAT.value
+    },
+    {
+        "fid": factory_infos[18]["name"],
+        "llm_name": "gemma2-9b-it",
+        "tags": "LLM,CHAT,15k",
+        "max_tokens": 8192,
+        "model_type": LLMType.CHAT.value
+    },
+    {
+        "fid": factory_infos[18]["name"],
+        "llm_name": "llama3-70b-8192",
+        "tags": "LLM,CHAT,6k",
+        "max_tokens": 8192,
+        "model_type": LLMType.CHAT.value
+    },
+    {
+        "fid": factory_infos[18]["name"],
+        "llm_name": "llama3-8b-8192",
+        "tags": "LLM,CHAT,30k",
+        "max_tokens": 8192,
+        "model_type": LLMType.CHAT.value
+    },
+    {
+        "fid": factory_infos[18]["name"],
+        "llm_name": "mixtral-8x7b-32768",
+        "tags": "LLM,CHAT,5k",
+        "max_tokens": 32768,
+        "model_type": LLMType.CHAT.value
+    }
 ]
 for info in factory_infos:
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index b07d98064..65a76e8f1 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -60,4 +60,5 @@ ChatModel = {
     "MiniMax": MiniMaxChat,
     "Mistral": MistralChat,
-    "Gemini": GeminiChat
+    "Gemini": GeminiChat,
+    "Groq": GroqChat
 }
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index f0fcb39bc..c3701b168 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -23,6 +23,7 @@ from ollama import Client
 from volcengine.maas.v2 import MaasService
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
+from groq import Groq
 
 
 class Base(ABC):
@@ -681,4 +682,63 @@ class GeminiChat(Base):
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
 
-        yield response._chunks[-1].usage_metadata.total_token_count
\ No newline at end of file
+        yield response._chunks[-1].usage_metadata.total_token_count
+
+
+class GroqChat:
+    def __init__(self, key, model_name, base_url=''):
+        self.client = Groq(api_key=key)
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        # Groq's OpenAI-compatible endpoint only accepts these options.
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+
+        ans = ""
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                **gen_conf
+            )
+            ans = response.choices[0].message.content
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.total_tokens
+        except Exception as e:
+            return ans + "\n**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                **gen_conf
+            )
+            for resp in response:
+                if not resp.choices or not resp.choices[0].delta.content:
+                    continue
+                ans += resp.choices[0].delta.content
+                # Approximate usage: count one token per streamed chunk.
+                total_tokens += 1
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 497c94683..6e45cb33a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -147,4 +147,5 @@ markdown==3.6
 mistralai==0.4.2
 boto3==1.34.140
 duckduckgo_search==6.1.9
-google-generativeai==0.7.2
\ No newline at end of file
+google-generativeai==0.7.2
+groq==0.9.0
\ No newline at end of file
diff --git a/requirements_arm.txt b/requirements_arm.txt
index a5cbf5d70..abec23a82 100644
--- a/requirements_arm.txt
+++ b/requirements_arm.txt
@@ -148,4 +148,5 @@ markdown==3.6
 mistralai==0.4.2
 boto3==1.34.140
 duckduckgo_search==6.1.9
-google-generativeai==0.7.2
\ No newline at end of file
+google-generativeai==0.7.2
+groq==0.9.0
diff --git a/requirements_dev.txt b/requirements_dev.txt
index f6b6799b5..22a422650 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -133,4 +133,5 @@ markdown==3.6
 mistralai==0.4.2
 boto3==1.34.140
 duckduckgo_search==6.1.9
-google-generativeai==0.7.2
\ No newline at end of file
+google-generativeai==0.7.2
+groq==0.9.0
\ No newline at end of file
diff --git a/web/src/assets/svg/llm/Groq.svg b/web/src/assets/svg/llm/Groq.svg
new file mode 100644
index 000000000..5608a42e4
--- /dev/null
+++ b/web/src/assets/svg/llm/Groq.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx
index 42f58a61e..6f4319e85 100644
--- a/web/src/pages/user-setting/setting-model/index.tsx
+++ b/web/src/pages/user-setting/setting-model/index.tsx
@@ -62,6 +62,7 @@ const IconMap = {
   'Azure-OpenAI': 'azure',
   Bedrock: 'bedrock',
   Gemini: 'gemini',
+  Groq: 'Groq',
 };
 
 const LlmIcon = ({ name }: { name: string }) => {
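For reviewers, a minimal smoke-test sketch of how the new `GroqChat` wrapper would be driven once the factory is registered. It is illustrative only: the `GROQ_API_KEY` environment variable, the choice of `llama3-8b-8192`, and the prompt contents are assumptions, not part of the patch.

```python
import os

from rag.llm.chat_model import GroqChat

# Assumes GROQ_API_KEY is set; any model id seeded in init_data.py works here.
mdl = GroqChat(key=os.environ["GROQ_API_KEY"], model_name="llama3-8b-8192")

question = [{"role": "user", "content": "Say hello in one sentence."}]

# Blocking call: returns the full answer and the total token usage.
ans, used = mdl.chat(
    "You are a concise assistant.",
    list(question),
    {"temperature": 0.2, "max_tokens": 64},
)
print(ans, used)

# Streaming call: yields the accumulated answer after each chunk,
# then the (approximate, chunk-counted) token total as the last item.
for out in mdl.chat_streamly(
    "You are a concise assistant.",
    list(question),
    {"temperature": 0.2, "max_tokens": 64},
):
    print(out)
```

Note that both methods mutate their arguments (the system message is prepended to `history`, and unsupported keys are pruned from `gen_conf` in place), hence the fresh copies passed to each call.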