diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index 2720e4978..a7718f074 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -134,6 +134,18 @@
             "max_tokens": 32768,
             "model_type": "chat"
         },
+        {
+            "llm_name": "qwq-32b",
+            "tags": "LLM,CHAT,128k",
+            "max_tokens": 131072,
+            "model_type": "chat"
+        },
+        {
+            "llm_name": "qwq-plus",
+            "tags": "LLM,CHAT,128k",
+            "max_tokens": 131072,
+            "model_type": "chat"
+        },
         {
             "llm_name": "qwen-long",
             "tags": "LLM,CHAT,10000K",
@@ -3259,7 +3271,7 @@
             "tags": "TEXT EMBEDDING,32000",
             "max_tokens": 32000,
             "model_type": "embedding"
-        },
+        },
         {
             "llm_name": "rerank-1",
             "tags": "RE-RANK, 8000",
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 2d7865bcd..ce6457bf7 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -268,13 +268,13 @@ class QWenChat(Base):
         import dashscope
         dashscope.api_key = key
         self.model_name = model_name
-        if model_name.lower().find("deepseek") >= 0:
+        if self.is_reasoning_model(self.model_name):
             super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1")
 
     def chat(self, system, history, gen_conf):
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
-        if self.model_name.lower().find("deepseek") >= 0:
+        if self.is_reasoning_model(self.model_name):
             return super().chat(system, history, gen_conf)
 
         stream_flag = str(os.environ.get('QWEN_CHAT_BY_STREAM', 'true')).lower() == 'true'
@@ -348,11 +348,19 @@ class QWenChat(Base):
     def chat_streamly(self, system, history, gen_conf):
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
-        if self.model_name.lower().find("deepseek") >= 0:
+        if self.is_reasoning_model(self.model_name):
             return super().chat_streamly(system, history, gen_conf)
 
         return self._chat_streamly(system, history, gen_conf)
 
+    @staticmethod
+    def is_reasoning_model(model_name: str) -> bool:
+        return any([
+            model_name.lower().find("deepseek") >= 0,
+            model_name.lower().find("qwq") >= 0 and model_name.lower() != 'qwq-32b-preview',
+        ])
+
+
 class ZhipuChat(Base):
     def __init__(self, key, model_name="glm-3-turbo", **kwargs):
@@ -740,7 +748,7 @@ class BedrockChat(Base):
         self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
         self.bedrock_region = json.loads(key).get('bedrock_region', '')
         self.model_name = model_name
-
+
         if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
             # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
             self.client = boto3.client('bedrock-runtime')
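Review note (not part of the patch): a minimal sanity check of the new
QWenChat.is_reasoning_model gate, assuming rag/llm/chat_model.py and its
dashscope dependency are importable in a dev environment. Model names that
should be routed through the OpenAI-compatible endpoint return True; the
patch deliberately excludes qwq-32b-preview, which stays on the native
dashscope path along with ordinary Qwen chat models.

    from rag.llm.chat_model import QWenChat

    # Reasoning models covered by this patch are routed to the
    # OpenAI-compatible endpoint (is_reasoning_model -> True).
    assert QWenChat.is_reasoning_model("qwq-32b")
    assert QWenChat.is_reasoning_model("qwq-plus")
    assert QWenChat.is_reasoning_model("deepseek-r1")

    # The exclusion of the preview build is case-insensitive, and
    # non-reasoning Qwen models are unaffected.
    assert not QWenChat.is_reasoning_model("QwQ-32B-Preview")
    assert not QWenChat.is_reasoning_model("qwen-long")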