feat: support llm env in env file (#251)

This commit is contained in:
DanielWalnut 2025-05-28 01:21:40 -07:00 committed by GitHub
parent 462752b462
commit 56e35c6b7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 6 deletions

View File

@ -18,6 +18,8 @@ def replace_env_vars(value: str) -> str:
def process_dict(config: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively process dictionary to replace environment variables."""
if not config:
return {}
result = {}
for key, value in config.items():
if isinstance(value, dict):

View File

@ -3,6 +3,7 @@
from pathlib import Path
from typing import Any, Dict
import os
from langchain_openai import ChatOpenAI
@ -13,18 +14,40 @@ from src.config.agents import LLMType
_llm_cache: dict[LLMType, ChatOpenAI] = {}
def _get_env_llm_conf(llm_type: str) -> Dict[str, Any]:
"""
Get LLM configuration from environment variables.
Environment variables should follow the format: {LLM_TYPE}__{KEY}
e.g., BASIC_MODEL__api_key, BASIC_MODEL__base_url
"""
prefix = f"{llm_type.upper()}_MODEL__"
conf = {}
for key, value in os.environ.items():
if key.startswith(prefix):
conf_key = key[len(prefix) :].lower()
conf[conf_key] = value
return conf
def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> ChatOpenAI:
llm_type_map = {
"reasoning": conf.get("REASONING_MODEL"),
"basic": conf.get("BASIC_MODEL"),
"vision": conf.get("VISION_MODEL"),
"reasoning": conf.get("REASONING_MODEL", {}),
"basic": conf.get("BASIC_MODEL", {}),
"vision": conf.get("VISION_MODEL", {}),
}
llm_conf = llm_type_map.get(llm_type)
if not llm_conf:
raise ValueError(f"Unknown LLM type: {llm_type}")
if not isinstance(llm_conf, dict):
raise ValueError(f"Invalid LLM Conf: {llm_type}")
return ChatOpenAI(**llm_conf)
# Get configuration from environment variables
env_conf = _get_env_llm_conf(llm_type)
# Merge configurations, with environment variables taking precedence
merged_conf = {**llm_conf, **env_conf}
if not merged_conf:
raise ValueError(f"Unknown LLM Conf: {llm_type}")
return ChatOpenAI(**merged_conf)
def get_llm_by_type(