diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 57cf53178..a4dcf2db3 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -14,6 +14,7 @@ # limitations under the License. # import re +import random from openai.lib.azure import AzureOpenAI from zhipuai import ZhipuAI @@ -28,6 +29,23 @@ import os import json import requests import asyncio +import logging +import time + + +# Error message constants +ERROR_PREFIX = "**ERROR**" +ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED" +ERROR_AUTHENTICATION = "AUTH_ERROR" +ERROR_INVALID_REQUEST = "INVALID_REQUEST" +ERROR_SERVER = "SERVER_ERROR" +ERROR_TIMEOUT = "TIMEOUT" +ERROR_CONNECTION = "CONNECTION_ERROR" +ERROR_MODEL = "MODEL_ERROR" +ERROR_CONTENT_FILTER = "CONTENT_FILTERED" +ERROR_QUOTA = "QUOTA_EXCEEDED" +ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED" +ERROR_GENERIC = "GENERIC_ERROR" LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。" LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length." 
class Base(ABC):
    """Chat-model base class: retry helpers and the blocking ``chat`` call.

    NOTE(review): only ``_get_delay``, ``_classify_error`` and ``chat`` are
    fully visible in this hunk. ``__init__`` (which builds ``self.client`` and
    reads the retry settings from the environment) and ``chat_streamly`` are
    only partially visible and are intentionally not reproduced here.
    """

    # Fallbacks mirroring the defaults ``__init__`` uses for LLM_MAX_RETRIES
    # and LLM_BASE_DELAY, so the retry logic still works on an instance whose
    # constructor did not assign them. Instance attributes take precedence.
    max_retries = 5
    base_delay = 2.0

    def _get_delay(self, attempt):
        """Return the back-off delay in seconds for a zero-based ``attempt``.

        The delay doubles with every attempt (``base_delay * 2 ** attempt``)
        and gets up to 0.5 s of random jitter so that concurrent callers do
        not retry in lock-step.
        """
        return self.base_delay * (2 ** attempt) + random.uniform(0, 0.5)

    def _classify_error(self, error):
        """Map an exception raised by the LLM client to an ``ERROR_*`` code.

        A numeric ``status_code`` attribute on the exception, when present, is
        trusted first for the statuses that decide retry behaviour (429, 401,
        403, 5xx). This keeps an incidental word such as "key" or "format" in
        the response body from hiding a retryable rate-limit or server error.
        Otherwise the lower-cased message is matched against keyword groups in
        a fixed priority order; numeric codes must appear as standalone
        numbers so that "429" does not match inside "4290".

        Args:
            error: The exception (or any object) describing the failure.

        Returns:
            One of the module-level ``ERROR_*`` string constants;
            ``ERROR_GENERIC`` when nothing matches.
        """
        status = getattr(error, "status_code", None)
        if isinstance(status, int):
            if status == 429:
                return ERROR_RATE_LIMIT
            if status in (401, 403):
                return ERROR_AUTHENTICATION
            if 500 <= status <= 599:
                return ERROR_SERVER

        text = str(error).lower()

        def mentions(*needles):
            # Plain substring match against the lower-cased message.
            return any(needle in text for needle in needles)

        def has_code(*codes):
            # Match a status code only when it is not part of a longer number.
            return any(re.search(rf"(?<!\d){code}(?!\d)", text) for code in codes)

        if mentions("rate limit", "tpm limit", "too many requests", "requests per minute") or has_code("429"):
            return ERROR_RATE_LIMIT
        if mentions("auth", "key", "apikey", "forbidden", "permission") or has_code("401"):
            return ERROR_AUTHENTICATION
        if mentions("invalid", "bad request", "format", "malformed", "parameter") or has_code("400"):
            return ERROR_INVALID_REQUEST
        if mentions("server", "unavailable") or has_code("500", "502", "503", "504"):
            return ERROR_SERVER
        if mentions("timeout", "timed out"):
            return ERROR_TIMEOUT
        if mentions("connect", "network", "unreachable", "dns"):
            return ERROR_CONNECTION
        # "limit" alone only counts as a quota problem when the message is not
        # about a rate; the parentheses make the intended grouping explicit.
        if mentions("quota", "capacity", "credit", "billing") or ("limit" in text and "rate" not in text):
            return ERROR_QUOTA
        if mentions("filter", "content", "policy", "blocked", "safety"):
            return ERROR_CONTENT_FILTER
        if mentions("model", "not found", "does not exist", "not available"):
            return ERROR_MODEL
        return ERROR_GENERIC

    def chat(self, system, history, gen_conf):
        """Send a chat completion request, retrying transient failures.

        Rate-limit and server errors are retried with exponential back-off
        (see ``_get_delay``); every other error is reported immediately.

        Args:
            system: Optional system prompt; when truthy it is inserted at the
                front of ``history`` (the caller's list is modified in place).
            history: List of chat messages passed as ``messages``.
            gen_conf: Extra keyword arguments for the completion call;
                ``max_tokens`` is removed from it in place before the call.

        Returns:
            ``(answer, token_count)``. ``("", 0)`` when the response carries no
            content. On failure the answer is
            ``"**ERROR**: <code> - <message>"`` and the count is 0; the code is
            ``ERROR_MAX_RETRIES`` only when a retryable error exhausted the
            retry budget, otherwise it is the classified error code.
        """
        if system:
            history.insert(0, {"role": "system", "content": system})
        gen_conf.pop("max_tokens", None)

        # Always make at least one request: a zero or negative LLM_MAX_RETRIES
        # would otherwise skip the loop and return None, breaking callers that
        # unpack the (answer, tokens) tuple.
        attempts = max(1, self.max_retries)

        for attempt in range(attempts):
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=history,
                    **gen_conf)

                # Short-circuit checks: indexing choices[0] must not happen
                # when the choices list is empty.
                choices = response.choices
                if not choices or not choices[0].message or not choices[0].message.content:
                    return "", 0

                ans = choices[0].message.content.strip()
                if choices[0].finish_reason == "length":
                    # The model stopped at its context limit; tell the user.
                    ans += LENGTH_NOTIFICATION_CN if is_chinese(ans) else LENGTH_NOTIFICATION_EN
                return ans, self.total_token_count(response)
            except Exception as e:
                error_code = self._classify_error(e)
                retryable = error_code in (ERROR_RATE_LIMIT, ERROR_SERVER)

                if retryable and attempt < attempts - 1:
                    delay = self._get_delay(attempt)
                    logging.warning(
                        "Error: %s. Retrying in %.2f seconds... (Attempt %d/%d)",
                        error_code, delay, attempt + 1, attempts)
                    time.sleep(delay)
                    continue

                # Only a retryable error on the final attempt means the retry
                # budget ran out; non-retryable errors keep their real code.
                if retryable:
                    error_code = ERROR_MAX_RETRIES
                return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0