From 56e35c6b7f315bbce24197c96c2eb062d2dfd604 Mon Sep 17 00:00:00 2001
From: DanielWalnut <45447813+hetaoBackend@users.noreply.github.com>
Date: Wed, 28 May 2025 01:21:40 -0700
Subject: [PATCH] feat: support llm env in env file (#251)

---
 src/config/loader.py |  2 ++
 src/llms/llm.py      | 35 +++++++++++++++++++++++++++++------
 2 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/src/config/loader.py b/src/config/loader.py
index a6edfe5..a1488c8 100644
--- a/src/config/loader.py
+++ b/src/config/loader.py
@@ -18,6 +18,8 @@ def replace_env_vars(value: str) -> str:
 
 def process_dict(config: Dict[str, Any]) -> Dict[str, Any]:
     """Recursively process dictionary to replace environment variables."""
+    if not config:
+        return {}
     result = {}
     for key, value in config.items():
         if isinstance(value, dict):
diff --git a/src/llms/llm.py b/src/llms/llm.py
index 81fb748..3f31189 100644
--- a/src/llms/llm.py
+++ b/src/llms/llm.py
@@ -3,6 +3,7 @@
 
 from pathlib import Path
 from typing import Any, Dict
+import os
 
 from langchain_openai import ChatOpenAI
 
@@ -13,18 +14,40 @@ from src.config.agents import LLMType
 _llm_cache: dict[LLMType, ChatOpenAI] = {}
 
 
+def _get_env_llm_conf(llm_type: str) -> Dict[str, Any]:
+    """
+    Get LLM configuration from environment variables.
+    Environment variables should follow the format: {LLM_TYPE}__{KEY}
+    e.g., BASIC_MODEL__api_key, BASIC_MODEL__base_url
+    """
+    prefix = f"{llm_type.upper()}_MODEL__"
+    conf = {}
+    for key, value in os.environ.items():
+        if key.startswith(prefix):
+            conf_key = key[len(prefix) :].lower()
+            conf[conf_key] = value
+    return conf
+
+
 def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> ChatOpenAI:
     llm_type_map = {
-        "reasoning": conf.get("REASONING_MODEL"),
-        "basic": conf.get("BASIC_MODEL"),
-        "vision": conf.get("VISION_MODEL"),
+        "reasoning": conf.get("REASONING_MODEL", {}),
+        "basic": conf.get("BASIC_MODEL", {}),
+        "vision": conf.get("VISION_MODEL", {}),
     }
     llm_conf = llm_type_map.get(llm_type)
-    if not llm_conf:
-        raise ValueError(f"Unknown LLM type: {llm_type}")
     if not isinstance(llm_conf, dict):
         raise ValueError(f"Invalid LLM Conf: {llm_type}")
-    return ChatOpenAI(**llm_conf)
+    # Get configuration from environment variables
+    env_conf = _get_env_llm_conf(llm_type)
+
+    # Merge configurations, with environment variables taking precedence
+    merged_conf = {**llm_conf, **env_conf}
+
+    if not merged_conf:
+        raise ValueError(f"Unknown LLM Conf: {llm_type}")
+
+    return ChatOpenAI(**merged_conf)
 
 
 def get_llm_by_type(
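
A minimal sketch of the override behavior this patch introduces, assuming
illustrative values: the environment variables and the file-derived dict
below are made up for demonstration and are not values from the repository.
The extraction and merge steps mirror _get_env_llm_conf and
_create_llm_use_conf from the diff above.

    import os

    # Hypothetical values, as if exported from a .env file (assumption)
    os.environ["BASIC_MODEL__api_key"] = "sk-from-env"
    os.environ["BASIC_MODEL__base_url"] = "https://example.com/v1"

    # Hypothetical values, as if loaded from conf.yaml (assumption)
    file_conf = {"api_key": "sk-from-file", "model": "gpt-4o"}

    # Same extraction as _get_env_llm_conf("basic"): strip the
    # {LLM_TYPE}_MODEL__ prefix and lowercase the remaining key
    prefix = "BASIC_MODEL__"
    env_conf = {
        key[len(prefix):].lower(): value
        for key, value in os.environ.items()
        if key.startswith(prefix)
    }

    # Merge as in _create_llm_use_conf: environment values win on conflicts
    merged_conf = {**file_conf, **env_conf}
    print(merged_conf)
    # {'api_key': 'sk-from-env', 'model': 'gpt-4o', 'base_url': 'https://example.com/v1'}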