diff --git a/api/settings.py b/api/settings.py index 3c3d29574..3aa7bf8f3 100644 --- a/api/settings.py +++ b/api/settings.py @@ -66,75 +66,28 @@ def init_settings(): DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') DATABASE = decrypt_database_config(name=DATABASE_TYPE) LLM = get_base_config("user_default_llm", {}) + LLM_DEFAULT_MODELS = LLM.get("default_models", {}) LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") LLM_BASE_URL = LLM.get("base_url") global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL if not LIGHTEN: - default_llm = { - "Tongyi-Qianwen": { - "chat_model": "qwen-plus", - "embedding_model": "text-embedding-v2", - "image2text_model": "qwen-vl-max", - "asr_model": "paraformer-realtime-8k-v1", - }, - "OpenAI": { - "chat_model": "gpt-3.5-turbo", - "embedding_model": "text-embedding-ada-002", - "image2text_model": "gpt-4-vision-preview", - "asr_model": "whisper-1", - }, - "Azure-OpenAI": { - "chat_model": "gpt-35-turbo", - "embedding_model": "text-embedding-ada-002", - "image2text_model": "gpt-4-vision-preview", - "asr_model": "whisper-1", - }, - "ZHIPU-AI": { - "chat_model": "glm-3-turbo", - "embedding_model": "embedding-2", - "image2text_model": "glm-4v", - "asr_model": "", - }, - "Ollama": { - "chat_model": "qwen-14B-chat", - "embedding_model": "flag-embedding", - "image2text_model": "", - "asr_model": "", - }, - "Moonshot": { - "chat_model": "moonshot-v1-8k", - "embedding_model": "", - "image2text_model": "", - "asr_model": "", - }, - "DeepSeek": { - "chat_model": "deepseek-chat", - "embedding_model": "", - "image2text_model": "", - "asr_model": "", - }, - "VolcEngine": { - "chat_model": "", - "embedding_model": "", - "image2text_model": "", - "asr_model": "", - }, - "BAAI": { - "chat_model": "", - "embedding_model": "BAAI/bge-large-zh-v1.5", - "image2text_model": "", - "asr_model": "", - "rerank_model": "BAAI/bge-reranker-v2-m3", - } - } + EMBEDDING_MDL = "BAAI/bge-large-zh-v1.5@BAAI" - if LLM_FACTORY: - CHAT_MDL = 
default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}" - ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}" - IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}" - EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI" - RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI" + if LLM_DEFAULT_MODELS: + CHAT_MDL = LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL) + EMBEDDING_MDL = LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL) + RERANK_MDL = LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL) + ASR_MDL = LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL) + IMAGE2TEXT_MDL = LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL) + + # A configured model name may carry its own factory as "model@factory"; a non-empty name without "@" gets "@" followed by LLM_FACTORY appended. + CHAT_MDL = CHAT_MDL + (f"@{LLM_FACTORY}" if "@" not in CHAT_MDL and CHAT_MDL != "" else "") + EMBEDDING_MDL = EMBEDDING_MDL + (f"@{LLM_FACTORY}" if "@" not in EMBEDDING_MDL and EMBEDDING_MDL != "" else "") + RERANK_MDL = RERANK_MDL + (f"@{LLM_FACTORY}" if "@" not in RERANK_MDL and RERANK_MDL != "" else "") + ASR_MDL = ASR_MDL + (f"@{LLM_FACTORY}" if "@" not in ASR_MDL and ASR_MDL != "" else "") + IMAGE2TEXT_MDL = IMAGE2TEXT_MDL + ( + f"@{LLM_FACTORY}" if "@" not in IMAGE2TEXT_MDL and IMAGE2TEXT_MDL != "" else "") global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY API_KEY = LLM.get("api_key", "") diff --git a/docker/service_conf.yaml.template b/docker/service_conf.yaml.template index f4acd8bc5..c667161f0 100644 --- a/docker/service_conf.yaml.template +++ b/docker/service_conf.yaml.template @@ -52,6 +52,12 @@ redis: # factory: 'Tongyi-Qianwen' # api_key: 'sk-xxxxxxxxxxxxx' # base_url: '' +# default_models: +# chat_model: 'qwen-plus' +# embedding_model: 'BAAI/bge-large-zh-v1.5@BAAI' +# rerank_model: '' +# asr_model: '' +# image2text_model: '' # oauth: # github: # client_id: xxxxxxxxxxxxxxxxxxxxxxxxx