From a0b461a18e19e3d932297114b172393792091fa0 Mon Sep 17 00:00:00 2001 From: Omar Leonardo Sanchez Granados Date: Sun, 23 Feb 2025 21:13:39 -0500 Subject: [PATCH] Add configuration to choose default llm models (#5245) ### What problem does this PR solve? This pull request includes changes to the `api/settings.py` and `docker/service_conf.yaml.template` files to add support for default models in the LLM configuration (especially for LIGHTEN builds). The most important changes include adding default model configurations and updating the initialization settings to use these defaults. For example: With this configuration, Bedrock will be enabled by default with Claude and Titan embeddings. ``` user_default_llm: factory: 'Bedrock' api_key: '{}' base_url: '' default_models: chat_model: 'anthropic.claude-3-5-sonnet-20240620-v1:0' embedding_model: 'amazon.titan-embed-text-v2:0' rerank_model: '' asr_model: '' image2text_model: '' ``` ### Type of change - [X] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Kevin Hu --- api/settings.py | 79 +++++++------------------ docker/service_conf.yaml.template | 6 +++ 2 files changed, 22 insertions(+), 63 deletions(-) diff --git a/api/settings.py b/api/settings.py index 3c3d29574..3aa7bf8f3 100644 --- a/api/settings.py +++ b/api/settings.py @@ -66,75 +66,28 @@ def init_settings(): DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') DATABASE = decrypt_database_config(name=DATABASE_TYPE) LLM = get_base_config("user_default_llm", {}) + LLM_DEFAULT_MODELS = LLM.get("default_models", {}) LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") LLM_BASE_URL = LLM.get("base_url") global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL if not LIGHTEN: - default_llm = { - "Tongyi-Qianwen": { - "chat_model": "qwen-plus", - "embedding_model": "text-embedding-v2", - "image2text_model": "qwen-vl-max", - "asr_model": "paraformer-realtime-8k-v1", - }, - "OpenAI": { - "chat_model": "gpt-3.5-turbo", - 
"embedding_model": "text-embedding-ada-002", - "image2text_model": "gpt-4-vision-preview", - "asr_model": "whisper-1", - }, - "Azure-OpenAI": { - "chat_model": "gpt-35-turbo", - "embedding_model": "text-embedding-ada-002", - "image2text_model": "gpt-4-vision-preview", - "asr_model": "whisper-1", - }, - "ZHIPU-AI": { - "chat_model": "glm-3-turbo", - "embedding_model": "embedding-2", - "image2text_model": "glm-4v", - "asr_model": "", - }, - "Ollama": { - "chat_model": "qwen-14B-chat", - "embedding_model": "flag-embedding", - "image2text_model": "", - "asr_model": "", - }, - "Moonshot": { - "chat_model": "moonshot-v1-8k", - "embedding_model": "", - "image2text_model": "", - "asr_model": "", - }, - "DeepSeek": { - "chat_model": "deepseek-chat", - "embedding_model": "", - "image2text_model": "", - "asr_model": "", - }, - "VolcEngine": { - "chat_model": "", - "embedding_model": "", - "image2text_model": "", - "asr_model": "", - }, - "BAAI": { - "chat_model": "", - "embedding_model": "BAAI/bge-large-zh-v1.5", - "image2text_model": "", - "asr_model": "", - "rerank_model": "BAAI/bge-reranker-v2-m3", - } - } + EMBEDDING_MDL = "BAAI/bge-large-zh-v1.5@BAAI" - if LLM_FACTORY: - CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}" - ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}" - IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}" - EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI" - RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI" + if LLM_DEFAULT_MODELS: + CHAT_MDL = LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL) + EMBEDDING_MDL = LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL) + RERANK_MDL = LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL) + ASR_MDL = LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL) + IMAGE2TEXT_MDL = LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL) + + # factory can be specified in the config name with "@". 
LLM_FACTORY will be used if not specified + CHAT_MDL = CHAT_MDL + (f"@{LLM_FACTORY}" if "@" not in CHAT_MDL and CHAT_MDL != "" else "") + EMBEDDING_MDL = EMBEDDING_MDL + (f"@{LLM_FACTORY}" if "@" not in EMBEDDING_MDL and EMBEDDING_MDL != "" else "") + RERANK_MDL = RERANK_MDL + (f"@{LLM_FACTORY}" if "@" not in RERANK_MDL and RERANK_MDL != "" else "") + ASR_MDL = ASR_MDL + (f"@{LLM_FACTORY}" if "@" not in ASR_MDL and ASR_MDL != "" else "") + IMAGE2TEXT_MDL = IMAGE2TEXT_MDL + ( + f"@{LLM_FACTORY}" if "@" not in IMAGE2TEXT_MDL and IMAGE2TEXT_MDL != "" else "") global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY API_KEY = LLM.get("api_key", "") diff --git a/docker/service_conf.yaml.template b/docker/service_conf.yaml.template index f4acd8bc5..c667161f0 100644 --- a/docker/service_conf.yaml.template +++ b/docker/service_conf.yaml.template @@ -52,6 +52,12 @@ redis: # factory: 'Tongyi-Qianwen' # api_key: 'sk-xxxxxxxxxxxxx' # base_url: '' +# default_models: +# chat_model: 'qwen-plus' +# embedding_model: 'BAAI/bge-large-zh-v1.5@BAAI' +# rerank_model: '' +# asr_model: '' +# image2text_model: '' # oauth: # github: # client_id: xxxxxxxxxxxxxxxxxxxxxxxxx