Merge pull request #2156 from cheahjs/feat/save-config

feat: save UI config changes to config.json
This commit is contained in:
Timothy Jaeryang Baek 2024-05-13 11:45:30 -10:00 committed by GitHub
commit 8b0144cd06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 632 additions and 369 deletions

View File

@ -45,6 +45,7 @@ from config import (
AUDIO_OPENAI_API_KEY, AUDIO_OPENAI_API_KEY,
AUDIO_OPENAI_API_MODEL, AUDIO_OPENAI_API_MODEL,
AUDIO_OPENAI_API_VOICE, AUDIO_OPENAI_API_VOICE,
AppConfig,
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -59,11 +60,11 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.config = AppConfig()
app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
app.state.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
app.state.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
# setting device type for whisper model # setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
@ -83,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/config") @app.get("/config")
async def get_openai_config(user=Depends(get_admin_user)): async def get_openai_config(user=Depends(get_admin_user)):
return { return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.OPENAI_API_MODEL, "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.OPENAI_API_VOICE, "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
} }
@ -97,17 +98,17 @@ async def update_openai_config(
if form_data.key == "": if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url app.state.config.OPENAI_API_BASE_URL = form_data.url
app.state.OPENAI_API_KEY = form_data.key app.state.config.OPENAI_API_KEY = form_data.key
app.state.OPENAI_API_MODEL = form_data.model app.state.config.OPENAI_API_MODEL = form_data.model
app.state.OPENAI_API_VOICE = form_data.speaker app.state.config.OPENAI_API_VOICE = form_data.speaker
return { return {
"status": True, "status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.OPENAI_API_MODEL, "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.OPENAI_API_VOICE, "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
} }
@ -124,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
r = None r = None
try: try:
r = requests.post( r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URL}/audio/speech", url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech",
data=body, data=body,
headers=headers, headers=headers,
stream=True, stream=True,

View File

@ -42,6 +42,7 @@ from config import (
IMAGE_GENERATION_MODEL, IMAGE_GENERATION_MODEL,
IMAGE_SIZE, IMAGE_SIZE,
IMAGE_STEPS, IMAGE_STEPS,
AppConfig,
) )
@ -60,26 +61,31 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.ENGINE = IMAGE_GENERATION_ENGINE app.state.config = AppConfig()
app.state.ENABLED = ENABLE_IMAGE_GENERATION
app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
app.state.MODEL = IMAGE_GENERATION_MODEL app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.MODEL = IMAGE_GENERATION_MODEL
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.IMAGE_SIZE = IMAGE_SIZE app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.IMAGE_STEPS = IMAGE_STEPS app.state.config.IMAGE_STEPS = IMAGE_STEPS
@app.get("/config") @app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)): async def get_config(request: Request, user=Depends(get_admin_user)):
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} return {
"engine": app.state.config.ENGINE,
"enabled": app.state.config.ENABLED,
}
class ConfigUpdateForm(BaseModel): class ConfigUpdateForm(BaseModel):
@ -89,9 +95,12 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update") @app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.ENGINE = form_data.engine app.state.config.ENGINE = form_data.engine
app.state.ENABLED = form_data.enabled app.state.config.ENABLED = form_data.enabled
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} return {
"engine": app.state.config.ENGINE,
"enabled": app.state.config.ENABLED,
}
class EngineUrlUpdateForm(BaseModel): class EngineUrlUpdateForm(BaseModel):
@ -102,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel):
@app.get("/url") @app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)): async def get_engine_url(user=Depends(get_admin_user)):
return { return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
} }
@ -113,29 +122,29 @@ async def update_engine_url(
): ):
if form_data.AUTOMATIC1111_BASE_URL == None: if form_data.AUTOMATIC1111_BASE_URL == None:
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else: else:
url = form_data.AUTOMATIC1111_BASE_URL.strip("/") url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
try: try:
r = requests.head(url) r = requests.head(url)
app.state.AUTOMATIC1111_BASE_URL = url app.state.config.AUTOMATIC1111_BASE_URL = url
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
if form_data.COMFYUI_BASE_URL == None: if form_data.COMFYUI_BASE_URL == None:
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else: else:
url = form_data.COMFYUI_BASE_URL.strip("/") url = form_data.COMFYUI_BASE_URL.strip("/")
try: try:
r = requests.head(url) r = requests.head(url)
app.state.COMFYUI_BASE_URL = url app.state.config.COMFYUI_BASE_URL = url
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return { return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
"status": True, "status": True,
} }
@ -148,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/openai/config") @app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)): async def get_openai_config(user=Depends(get_admin_user)):
return { return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
} }
@ -160,13 +169,13 @@ async def update_openai_config(
if form_data.key == "": if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url app.state.config.OPENAI_API_BASE_URL = form_data.url
app.state.OPENAI_API_KEY = form_data.key app.state.config.OPENAI_API_KEY = form_data.key
return { return {
"status": True, "status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
} }
@ -176,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel):
@app.get("/size") @app.get("/size")
async def get_image_size(user=Depends(get_admin_user)): async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_SIZE": app.state.IMAGE_SIZE} return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
@app.post("/size/update") @app.post("/size/update")
@ -185,9 +194,9 @@ async def update_image_size(
): ):
pattern = r"^\d+x\d+$" # Regular expression pattern pattern = r"^\d+x\d+$" # Regular expression pattern
if re.match(pattern, form_data.size): if re.match(pattern, form_data.size):
app.state.IMAGE_SIZE = form_data.size app.state.config.IMAGE_SIZE = form_data.size
return { return {
"IMAGE_SIZE": app.state.IMAGE_SIZE, "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
"status": True, "status": True,
} }
else: else:
@ -203,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel):
@app.get("/steps") @app.get("/steps")
async def get_image_size(user=Depends(get_admin_user)): async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_STEPS": app.state.IMAGE_STEPS} return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
@app.post("/steps/update") @app.post("/steps/update")
@ -211,9 +220,9 @@ async def update_image_size(
form_data: ImageStepsUpdateForm, user=Depends(get_admin_user) form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
): ):
if form_data.steps >= 0: if form_data.steps >= 0:
app.state.IMAGE_STEPS = form_data.steps app.state.config.IMAGE_STEPS = form_data.steps
return { return {
"IMAGE_STEPS": app.state.IMAGE_STEPS, "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
"status": True, "status": True,
} }
else: else:
@ -226,14 +235,14 @@ async def update_image_size(
@app.get("/models") @app.get("/models")
def get_models(user=Depends(get_current_user)): def get_models(user=Depends(get_current_user)):
try: try:
if app.state.ENGINE == "openai": if app.state.config.ENGINE == "openai":
return [ return [
{"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"}, {"id": "dall-e-3", "name": "DALL·E 3"},
] ]
elif app.state.ENGINE == "comfyui": elif app.state.config.ENGINE == "comfyui":
r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info") r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
info = r.json() info = r.json()
return list( return list(
@ -245,7 +254,7 @@ def get_models(user=Depends(get_current_user)):
else: else:
r = requests.get( r = requests.get(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
) )
models = r.json() models = r.json()
return list( return list(
@ -255,23 +264,29 @@ def get_models(user=Depends(get_current_user)):
) )
) )
except Exception as e: except Exception as e:
app.state.ENABLED = False app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models/default") @app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)): async def get_default_model(user=Depends(get_admin_user)):
try: try:
if app.state.ENGINE == "openai": if app.state.config.ENGINE == "openai":
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} return {
elif app.state.ENGINE == "comfyui": "model": (
return {"model": app.state.MODEL if app.state.MODEL else ""} app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
)
}
elif app.state.config.ENGINE == "comfyui":
return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
else: else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
)
options = r.json() options = r.json()
return {"model": options["sd_model_checkpoint"]} return {"model": options["sd_model_checkpoint"]}
except Exception as e: except Exception as e:
app.state.ENABLED = False app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@ -280,20 +295,20 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str): def set_model_handler(model: str):
if app.state.ENGINE == "openai": if app.state.config.ENGINE in ["openai", "comfyui"]:
app.state.MODEL = model app.state.config.MODEL = model
return app.state.MODEL return app.state.config.MODEL
if app.state.ENGINE == "comfyui":
app.state.MODEL = model
return app.state.MODEL
else: else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
)
options = r.json() options = r.json()
if model != options["sd_model_checkpoint"]: if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model options["sd_model_checkpoint"] = model
r = requests.post( r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
json=options,
) )
return options return options
@ -382,26 +397,32 @@ def generate_image(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) width, height = tuple(map(int, app.state.config.IMAGE_SIZE).split("x"))
r = None r = None
try: try:
if app.state.ENGINE == "openai": if app.state.config.ENGINE == "openai":
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
data = { data = {
"model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", "model": (
app.state.config.MODEL
if app.state.config.MODEL != ""
else "dall-e-2"
),
"prompt": form_data.prompt, "prompt": form_data.prompt,
"n": form_data.n, "n": form_data.n,
"size": form_data.size if form_data.size else app.state.IMAGE_SIZE, "size": (
form_data.size if form_data.size else app.state.config.IMAGE_SIZE
),
"response_format": "b64_json", "response_format": "b64_json",
} }
r = requests.post( r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URL}/images/generations", url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
json=data, json=data,
headers=headers, headers=headers,
) )
@ -421,7 +442,7 @@ def generate_image(
return images return images
elif app.state.ENGINE == "comfyui": elif app.state.config.ENGINE == "comfyui":
data = { data = {
"prompt": form_data.prompt, "prompt": form_data.prompt,
@ -430,19 +451,19 @@ def generate_image(
"n": form_data.n, "n": form_data.n,
} }
if app.state.IMAGE_STEPS != None: if app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.IMAGE_STEPS data["steps"] = app.state.config.IMAGE_STEPS
if form_data.negative_prompt != None: if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt data["negative_prompt"] = form_data.negative_prompt
data = ImageGenerationPayload(**data) data = ImageGenerationPayload(**data)
res = comfyui_generate_image( res = comfyui_generate_image(
app.state.MODEL, app.state.config.MODEL,
data, data,
user.id, user.id,
app.state.COMFYUI_BASE_URL, app.state.config.COMFYUI_BASE_URL,
) )
log.debug(f"res: {res}") log.debug(f"res: {res}")
@ -469,14 +490,14 @@ def generate_image(
"height": height, "height": height,
} }
if app.state.IMAGE_STEPS != None: if app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.IMAGE_STEPS data["steps"] = app.state.config.IMAGE_STEPS
if form_data.negative_prompt != None: if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt data["negative_prompt"] = form_data.negative_prompt
r = requests.post( r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data, json=data,
) )

View File

@ -46,6 +46,7 @@ from config import (
ENABLE_MODEL_FILTER, ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
UPLOAD_DIR, UPLOAD_DIR,
AppConfig,
) )
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256
@ -61,11 +62,12 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.config = AppConfig()
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {} app.state.MODELS = {}
@ -96,7 +98,7 @@ async def get_status():
@app.get("/urls") @app.get("/urls")
async def get_ollama_api_urls(user=Depends(get_admin_user)): async def get_ollama_api_urls(user=Depends(get_admin_user)):
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
class UrlUpdateForm(BaseModel): class UrlUpdateForm(BaseModel):
@ -105,10 +107,10 @@ class UrlUpdateForm(BaseModel):
@app.post("/urls/update") @app.post("/urls/update")
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
app.state.OLLAMA_BASE_URLS = form_data.urls app.state.config.OLLAMA_BASE_URLS = form_data.urls
log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}") log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}")
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
@app.get("/cancel/{request_id}") @app.get("/cancel/{request_id}")
@ -153,7 +155,7 @@ def merge_models_lists(model_lists):
async def get_all_models(): async def get_all_models():
log.info("get_all_models()") log.info("get_all_models()")
tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
models = { models = {
@ -186,7 +188,7 @@ async def get_ollama_tags(
return models return models
return models return models
else: else:
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try: try:
r = requests.request(method="GET", url=f"{url}/api/tags") r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status() r.raise_for_status()
@ -216,7 +218,9 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
if url_idx == None: if url_idx == None:
# returns lowest version # returns lowest version
tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS] tasks = [
fetch_url(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS
]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses)) responses = list(filter(lambda x: x is not None, responses))
@ -235,7 +239,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
) )
else: else:
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try: try:
r = requests.request(method="GET", url=f"{url}/api/version") r = requests.request(method="GET", url=f"{url}/api/version")
r.raise_for_status() r.raise_for_status()
@ -267,7 +271,7 @@ class ModelNameForm(BaseModel):
async def pull_model( async def pull_model(
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
): ):
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
@ -355,7 +359,7 @@ async def push_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.debug(f"url: {url}") log.debug(f"url: {url}")
r = None r = None
@ -417,7 +421,7 @@ async def create_model(
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
): ):
log.debug(f"form_data: {form_data}") log.debug(f"form_data: {form_data}")
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
@ -490,7 +494,7 @@ async def copy_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
@ -537,7 +541,7 @@ async def delete_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
@ -577,7 +581,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
) )
url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
@ -634,7 +638,7 @@ async def generate_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
@ -684,7 +688,7 @@ def generate_ollama_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
@ -753,7 +757,7 @@ async def generate_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
@ -856,7 +860,7 @@ async def generate_chat_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
@ -965,7 +969,7 @@ async def generate_openai_chat_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
@ -1064,7 +1068,7 @@ async def get_openai_models(
} }
else: else:
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try: try:
r = requests.request(method="GET", url=f"{url}/api/tags") r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status() r.raise_for_status()
@ -1198,7 +1202,7 @@ async def download_model(
if url_idx == None: if url_idx == None:
url_idx = 0 url_idx = 0
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
file_name = parse_huggingface_url(form_data.url) file_name = parse_huggingface_url(form_data.url)
@ -1217,7 +1221,7 @@ async def download_model(
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
if url_idx == None: if url_idx == None:
url_idx = 0 url_idx = 0
ollama_url = app.state.OLLAMA_BASE_URLS[url_idx] ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
file_path = f"{UPLOAD_DIR}/{file.filename}" file_path = f"{UPLOAD_DIR}/{file.filename}"
@ -1282,7 +1286,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None): # async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None: # if url_idx == None:
# url_idx = 0 # url_idx = 0
# url = app.state.OLLAMA_BASE_URLS[url_idx] # url = app.state.config.OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename) # file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size # total_size = file.size
@ -1319,7 +1323,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
async def deprecated_proxy( async def deprecated_proxy(
path: str, request: Request, user=Depends(get_verified_user) path: str, request: Request, user=Depends(get_verified_user)
): ):
url = app.state.OLLAMA_BASE_URLS[0] url = app.state.config.OLLAMA_BASE_URLS[0]
target_url = f"{url}/{path}" target_url = f"{url}/{path}"
body = await request.body() body = await request.body()

View File

@ -26,6 +26,7 @@ from config import (
CACHE_DIR, CACHE_DIR,
ENABLE_MODEL_FILTER, ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
AppConfig,
) )
from typing import List, Optional from typing import List, Optional
@ -45,11 +46,13 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.config = AppConfig()
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.MODELS = {} app.state.MODELS = {}
@ -75,32 +78,32 @@ class KeysUpdateForm(BaseModel):
@app.get("/urls") @app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)): async def get_openai_urls(user=Depends(get_admin_user)):
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.post("/urls/update") @app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
await get_all_models() await get_all_models()
app.state.OPENAI_API_BASE_URLS = form_data.urls app.state.config.OPENAI_API_BASE_URLS = form_data.urls
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.get("/keys") @app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)): async def get_openai_keys(user=Depends(get_admin_user)):
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/keys/update") @app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
app.state.OPENAI_API_KEYS = form_data.keys app.state.config.OPENAI_API_KEYS = form_data.keys
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/audio/speech") @app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)): async def speech(request: Request, user=Depends(get_verified_user)):
idx = None idx = None
try: try:
idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
body = await request.body() body = await request.body()
name = hashlib.sha256(body).hexdigest() name = hashlib.sha256(body).hexdigest()
@ -114,7 +117,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}" headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
if "openrouter.ai" in app.state.OPENAI_API_BASE_URLS[idx]: if "openrouter.ai" in app.state.OPENAI_API_BASE_URLS[idx]:
headers['HTTP-Referer'] = "https://openwebui.com/" headers['HTTP-Referer'] = "https://openwebui.com/"
@ -122,7 +125,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
r = None r = None
try: try:
r = requests.post( r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech", url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
data=body, data=body,
headers=headers, headers=headers,
stream=True, stream=True,
@ -182,7 +185,8 @@ def merge_models_lists(model_lists):
[ [
{**model, "urlIdx": idx} {**model, "urlIdx": idx}
for model in models for model in models
if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx] if "api.openai.com"
not in app.state.config.OPENAI_API_BASE_URLS[idx]
or "gpt" in model["id"] or "gpt" in model["id"]
] ]
) )
@ -193,12 +197,15 @@ def merge_models_lists(model_lists):
async def get_all_models(): async def get_all_models():
log.info("get_all_models()") log.info("get_all_models()")
if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "": if (
len(app.state.config.OPENAI_API_KEYS) == 1
and app.state.config.OPENAI_API_KEYS[0] == ""
):
models = {"data": []} models = {"data": []}
else: else:
tasks = [ tasks = [
fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx]) fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS) for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
] ]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
@ -241,7 +248,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
return models return models
return models return models
else: else:
url = app.state.OPENAI_API_BASE_URLS[url_idx] url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
r = None r = None
@ -305,8 +312,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
log.error("Error loading request body into a dictionary:", e) log.error("Error loading request body into a dictionary:", e)
url = app.state.OPENAI_API_BASE_URLS[idx] url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.OPENAI_API_KEYS[idx] key = app.state.config.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}" target_url = f"{url}/{path}"

View File

@ -93,6 +93,7 @@ from config import (
RAG_TEMPLATE, RAG_TEMPLATE,
ENABLE_RAG_LOCAL_WEB_FETCH, ENABLE_RAG_LOCAL_WEB_FETCH,
YOUTUBE_LOADER_LANGUAGE, YOUTUBE_LOADER_LANGUAGE,
AppConfig,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -102,30 +103,32 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
app = FastAPI() app = FastAPI()
app.state.TOP_K = RAG_TOP_K app.state.config = AppConfig()
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH app.state.config.TOP_K = RAG_TOP_K
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
) )
app.state.CHUNK_SIZE = CHUNK_SIZE app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None app.state.YOUTUBE_LOADER_TRANSLATION = None
@ -133,7 +136,7 @@ def update_embedding_model(
embedding_model: str, embedding_model: str,
update_model: bool = False, update_model: bool = False,
): ):
if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "": if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model), get_model_path(embedding_model, update_model),
device=DEVICE_TYPE, device=DEVICE_TYPE,
@ -158,22 +161,22 @@ def update_reranking_model(
update_embedding_model( update_embedding_model(
app.state.RAG_EMBEDDING_MODEL, app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_AUTO_UPDATE,
) )
update_reranking_model( update_reranking_model(
app.state.RAG_RERANKING_MODEL, app.state.config.RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_AUTO_UPDATE,
) )
app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL, app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY, app.state.config.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL, app.state.config.OPENAI_API_BASE_URL,
) )
origins = ["*"] origins = ["*"]
@ -200,12 +203,12 @@ class UrlForm(CollectionNameForm):
async def get_status(): async def get_status():
return { return {
"status": True, "status": True,
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": app.state.config.CHUNK_OVERLAP,
"template": app.state.RAG_TEMPLATE, "template": app.state.config.RAG_TEMPLATE,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.RAG_RERANKING_MODEL, "reranking_model": app.state.config.RAG_RERANKING_MODEL,
} }
@ -213,18 +216,21 @@ async def get_status():
async def get_embedding_config(user=Depends(get_admin_user)): async def get_embedding_config(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"openai_config": { "openai_config": {
"url": app.state.OPENAI_API_BASE_URL, "url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.OPENAI_API_KEY, "key": app.state.config.OPENAI_API_KEY,
}, },
} }
@app.get("/reranking") @app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)): async def get_reraanking_config(user=Depends(get_admin_user)):
return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL} return {
"status": True,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
}
class OpenAIConfigForm(BaseModel): class OpenAIConfigForm(BaseModel):
@ -243,34 +249,34 @@ async def update_embedding_config(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
): ):
log.info( log.info(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
) )
try: try:
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config != None: if form_data.openai_config != None:
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.OPENAI_API_KEY = form_data.openai_config.key app.state.config.OPENAI_API_KEY = form_data.openai_config.key
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True) update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL), True
app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL, app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY, app.state.config.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL, app.state.config.OPENAI_API_BASE_URL,
) )
return { return {
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"openai_config": { "openai_config": {
"url": app.state.OPENAI_API_BASE_URL, "url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.OPENAI_API_KEY, "key": app.state.config.OPENAI_API_KEY,
}, },
} }
except Exception as e: except Exception as e:
@ -290,16 +296,16 @@ async def update_reranking_config(
form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
): ):
log.info( log.info(
f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}" f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
) )
try: try:
app.state.RAG_RERANKING_MODEL = form_data.reranking_model app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
update_reranking_model(app.state.RAG_RERANKING_MODEL, True) update_reranking_model(app.state.config.RAG_RERANKING_MODEL), True
return { return {
"status": True, "status": True,
"reranking_model": app.state.RAG_RERANKING_MODEL, "reranking_model": app.state.config.RAG_RERANKING_MODEL,
} }
except Exception as e: except Exception as e:
log.exception(f"Problem updating reranking model: {e}") log.exception(f"Problem updating reranking model: {e}")
@ -313,14 +319,14 @@ async def update_reranking_config(
async def get_rag_config(user=Depends(get_admin_user)): async def get_rag_config(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"chunk": { "chunk": {
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": app.state.config.CHUNK_OVERLAP,
}, },
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": { "youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE, "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION, "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
}, },
} }
@ -345,50 +351,52 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update") @app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.PDF_EXTRACT_IMAGES = ( app.state.config.PDF_EXTRACT_IMAGES = (
form_data.pdf_extract_images form_data.pdf_extract_images
if form_data.pdf_extract_images != None if form_data.pdf_extract_images is not None
else app.state.PDF_EXTRACT_IMAGES else app.state.config.PDF_EXTRACT_IMAGES
) )
app.state.CHUNK_SIZE = ( app.state.config.CHUNK_SIZE = (
form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE form_data.chunk.chunk_size
if form_data.chunk is not None
else app.state.config.CHUNK_SIZE
) )
app.state.CHUNK_OVERLAP = ( app.state.config.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap form_data.chunk.chunk_overlap
if form_data.chunk != None if form_data.chunk is not None
else app.state.CHUNK_OVERLAP else app.state.config.CHUNK_OVERLAP
) )
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None if form_data.web_loader_ssl_verification != None
else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
) )
app.state.YOUTUBE_LOADER_LANGUAGE = ( app.state.config.YOUTUBE_LOADER_LANGUAGE = (
form_data.youtube.language form_data.youtube.language
if form_data.youtube != None if form_data.youtube is not None
else app.state.YOUTUBE_LOADER_LANGUAGE else app.state.config.YOUTUBE_LOADER_LANGUAGE
) )
app.state.YOUTUBE_LOADER_TRANSLATION = ( app.state.YOUTUBE_LOADER_TRANSLATION = (
form_data.youtube.translation form_data.youtube.translation
if form_data.youtube != None if form_data.youtube is not None
else app.state.YOUTUBE_LOADER_TRANSLATION else app.state.YOUTUBE_LOADER_TRANSLATION
) )
return { return {
"status": True, "status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"chunk": { "chunk": {
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": app.state.config.CHUNK_OVERLAP,
}, },
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": { "youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE, "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION, "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
}, },
} }
@ -398,7 +406,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
async def get_rag_template(user=Depends(get_current_user)): async def get_rag_template(user=Depends(get_current_user)):
return { return {
"status": True, "status": True,
"template": app.state.RAG_TEMPLATE, "template": app.state.config.RAG_TEMPLATE,
} }
@ -406,10 +414,10 @@ async def get_rag_template(user=Depends(get_current_user)):
async def get_query_settings(user=Depends(get_admin_user)): async def get_query_settings(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"template": app.state.RAG_TEMPLATE, "template": app.state.config.RAG_TEMPLATE,
"k": app.state.TOP_K, "k": app.state.config.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD, "r": app.state.config.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
} }
@ -424,16 +432,20 @@ class QuerySettingsForm(BaseModel):
async def update_query_settings( async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user) form_data: QuerySettingsForm, user=Depends(get_admin_user)
): ):
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE app.state.config.RAG_TEMPLATE = (
app.state.TOP_K = form_data.k if form_data.k else 4 form_data.template if form_data.template else RAG_TEMPLATE,
app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 )
app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False app.state.config.TOP_K = form_data.k if form_data.k else 4
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
form_data.hybrid if form_data.hybrid else False,
)
return { return {
"status": True, "status": True,
"template": app.state.RAG_TEMPLATE, "template": app.state.config.RAG_TEMPLATE,
"k": app.state.TOP_K, "k": app.state.config.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD, "r": app.state.config.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
} }
@ -451,21 +463,23 @@ def query_doc_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
if app.state.ENABLE_RAG_HYBRID_SEARCH: if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
return query_doc_with_hybrid_search( return query_doc_with_hybrid_search(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query=form_data.query, query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.config.TOP_K,
reranking_function=app.state.sentence_transformer_rf, reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, r=(
form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
),
) )
else: else:
return query_doc( return query_doc(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query=form_data.query, query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.config.TOP_K,
) )
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -489,21 +503,23 @@ def query_collection_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
if app.state.ENABLE_RAG_HYBRID_SEARCH: if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
return query_collection_with_hybrid_search( return query_collection_with_hybrid_search(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
query=form_data.query, query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.config.TOP_K,
reranking_function=app.state.sentence_transformer_rf, reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, r=(
form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
),
) )
else: else:
return query_collection( return query_collection(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
query=form_data.query, query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.config.TOP_K,
) )
except Exception as e: except Exception as e:
@ -520,7 +536,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
loader = YoutubeLoader.from_youtube_url( loader = YoutubeLoader.from_youtube_url(
form_data.url, form_data.url,
add_video_info=True, add_video_info=True,
language=app.state.YOUTUBE_LOADER_LANGUAGE, language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
translation=app.state.YOUTUBE_LOADER_TRANSLATION, translation=app.state.YOUTUBE_LOADER_TRANSLATION,
) )
data = loader.load() data = loader.load()
@ -548,7 +564,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try: try:
loader = get_web_loader( loader = get_web_loader(
form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION form_data.url,
verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
) )
data = loader.load() data = loader.load()
@ -604,8 +621,8 @@ def resolve_hostname(hostname):
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE, chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP, chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True, add_start_index=True,
) )
@ -622,8 +639,8 @@ def store_text_in_vector_db(
text, metadata, collection_name, overwrite: bool = False text, metadata, collection_name, overwrite: bool = False
) -> bool: ) -> bool:
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE, chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP, chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True, add_start_index=True,
) )
docs = text_splitter.create_documents([text], metadatas=[metadata]) docs = text_splitter.create_documents([text], metadatas=[metadata])
@ -646,11 +663,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection = CHROMA_CLIENT.create_collection(name=collection_name) collection = CHROMA_CLIENT.create_collection(name=collection_name)
embedding_func = get_embedding_function( embedding_func = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL, app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY, app.state.config.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL, app.state.config.OPENAI_API_BASE_URL,
) )
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts)) embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
@ -724,7 +741,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
] ]
if file_ext == "pdf": if file_ext == "pdf":
loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES) loader = PyPDFLoader(
file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
)
elif file_ext == "csv": elif file_ext == "csv":
loader = CSVLoader(file_path) loader = CSVLoader(file_path)
elif file_ext == "rst": elif file_ext == "rst":

View File

@ -21,20 +21,24 @@ from config import (
USER_PERMISSIONS, USER_PERMISSIONS,
WEBHOOK_URL, WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN,
AppConfig,
) )
app = FastAPI() app = FastAPI()
origins = ["*"] origins = ["*"]
app.state.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config = AppConfig()
app.state.JWT_EXPIRES_IN = "-1"
app.state.DEFAULT_MODELS = DEFAULT_MODELS app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.WEBHOOK_URL = WEBHOOK_URL app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.add_middleware( app.add_middleware(
@ -61,6 +65,6 @@ async def get_status():
return { return {
"status": True, "status": True,
"auth": WEBUI_AUTH, "auth": WEBUI_AUTH,
"default_models": app.state.DEFAULT_MODELS, "default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
} }

View File

@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm):
if user: if user:
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
) )
return { return {
@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse) @router.post("/signup", response_model=SigninResponse)
async def signup(request: Request, form_data: SignupForm): async def signup(request: Request, form_data: SignupForm):
if not request.app.state.ENABLE_SIGNUP and WEBUI_AUTH: if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
raise HTTPException( raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
) )
@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm):
role = ( role = (
"admin" "admin"
if Users.get_num_users() == 0 if Users.get_num_users() == 0
else request.app.state.DEFAULT_USER_ROLE else request.app.state.config.DEFAULT_USER_ROLE
) )
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
@ -194,13 +194,13 @@ async def signup(request: Request, form_data: SignupForm):
if user: if user:
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
) )
# response.set_cookie(key='token', value=token, httponly=True) # response.set_cookie(key='token', value=token, httponly=True)
if request.app.state.WEBHOOK_URL: if request.app.state.config.WEBHOOK_URL:
post_webhook( post_webhook(
request.app.state.WEBHOOK_URL, request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name), WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{ {
"action": "signup", "action": "signup",
@ -276,13 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
@router.get("/signup/enabled", response_model=bool) @router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)): async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
return request.app.state.ENABLE_SIGNUP return request.app.state.config.ENABLE_SIGNUP
@router.get("/signup/enabled/toggle", response_model=bool) @router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP
return request.app.state.ENABLE_SIGNUP return request.app.state.config.ENABLE_SIGNUP
############################ ############################
@ -292,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
@router.get("/signup/user/role") @router.get("/signup/user/role")
async def get_default_user_role(request: Request, user=Depends(get_admin_user)): async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
return request.app.state.DEFAULT_USER_ROLE return request.app.state.config.DEFAULT_USER_ROLE
class UpdateRoleForm(BaseModel): class UpdateRoleForm(BaseModel):
@ -304,8 +304,8 @@ async def update_default_user_role(
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user) request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
): ):
if form_data.role in ["pending", "user", "admin"]: if form_data.role in ["pending", "user", "admin"]:
request.app.state.DEFAULT_USER_ROLE = form_data.role request.app.state.config.DEFAULT_USER_ROLE = form_data.role
return request.app.state.DEFAULT_USER_ROLE return request.app.state.config.DEFAULT_USER_ROLE
############################ ############################
@ -315,7 +315,7 @@ async def update_default_user_role(
@router.get("/token/expires") @router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)): async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
return request.app.state.JWT_EXPIRES_IN return request.app.state.config.JWT_EXPIRES_IN
class UpdateJWTExpiresDurationForm(BaseModel): class UpdateJWTExpiresDurationForm(BaseModel):
@ -332,10 +332,10 @@ async def update_token_expires_duration(
# Check if the input string matches the pattern # Check if the input string matches the pattern
if re.match(pattern, form_data.duration): if re.match(pattern, form_data.duration):
request.app.state.JWT_EXPIRES_IN = form_data.duration request.app.state.config.JWT_EXPIRES_IN = form_data.duration
return request.app.state.JWT_EXPIRES_IN return request.app.state.config.JWT_EXPIRES_IN
else: else:
return request.app.state.JWT_EXPIRES_IN return request.app.state.config.JWT_EXPIRES_IN
############################ ############################

View File

@ -44,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel):
async def set_global_default_models( async def set_global_default_models(
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user) request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
): ):
request.app.state.DEFAULT_MODELS = form_data.models request.app.state.config.DEFAULT_MODELS = form_data.models
return request.app.state.DEFAULT_MODELS return request.app.state.config.DEFAULT_MODELS
@router.post("/default/suggestions", response_model=List[PromptSuggestion]) @router.post("/default/suggestions", response_model=List[PromptSuggestion])
@ -55,5 +55,5 @@ async def set_global_default_suggestions(
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
data = form_data.model_dump() data = form_data.model_dump()
request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
return request.app.state.DEFAULT_PROMPT_SUGGESTIONS return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS

View File

@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
@router.get("/permissions/user") @router.get("/permissions/user")
async def get_user_permissions(request: Request, user=Depends(get_admin_user)): async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
return request.app.state.USER_PERMISSIONS return request.app.state.config.USER_PERMISSIONS
@router.post("/permissions/user") @router.post("/permissions/user")
async def update_user_permissions( async def update_user_permissions(
request: Request, form_data: dict, user=Depends(get_admin_user) request: Request, form_data: dict, user=Depends(get_admin_user)
): ):
request.app.state.USER_PERMISSIONS = form_data request.app.state.config.USER_PERMISSIONS = form_data
return request.app.state.USER_PERMISSIONS return request.app.state.config.USER_PERMISSIONS
############################ ############################

View File

@ -5,6 +5,7 @@ import chromadb
from chromadb import Settings from chromadb import Settings
from base64 import b64encode from base64 import b64encode
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from typing import TypeVar, Generic, Union
from pathlib import Path from pathlib import Path
import json import json
@ -17,7 +18,6 @@ import shutil
from secrets import token_bytes from secrets import token_bytes
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
#################################### ####################################
# Load .env file # Load .env file
#################################### ####################################
@ -71,7 +71,6 @@ for source in log_sources:
log.setLevel(SRC_LOG_LEVELS["CONFIG"]) log.setLevel(SRC_LOG_LEVELS["CONFIG"])
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI": if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)" WEBUI_NAME += " (Open WebUI)"
@ -161,16 +160,6 @@ CHANGELOG = changelog_json
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100") WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
####################################
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
#################################### ####################################
# DATA/FRONTEND BUILD DIR # DATA/FRONTEND BUILD DIR
#################################### ####################################
@ -184,6 +173,108 @@ try:
except: except:
CONFIG_DATA = {} CONFIG_DATA = {}
####################################
# Config helpers
####################################
def save_config():
try:
with open(f"{DATA_DIR}/config.json", "w") as f:
json.dump(CONFIG_DATA, f, indent="\t")
except Exception as e:
log.exception(e)
def get_config_value(config_path: str):
path_parts = config_path.split(".")
cur_config = CONFIG_DATA
for key in path_parts:
if key in cur_config:
cur_config = cur_config[key]
else:
return None
return cur_config
T = TypeVar("T")
class PersistentConfig(Generic[T]):
def __init__(self, env_name: str, config_path: str, env_value: T):
self.env_name = env_name
self.config_path = config_path
self.env_value = env_value
self.config_value = get_config_value(config_path)
if self.config_value is not None:
log.info(f"'{env_name}' loaded from config.json")
self.value = self.config_value
else:
self.value = env_value
def __str__(self):
return str(self.value)
@property
def __dict__(self):
raise TypeError(
"PersistentConfig object cannot be converted to dict, use config_get or .value instead."
)
def __getattribute__(self, item):
if item == "__dict__":
raise TypeError(
"PersistentConfig object cannot be converted to dict, use config_get or .value instead."
)
return super().__getattribute__(item)
def save(self):
# Don't save if the value is the same as the env value and the config value
if self.env_value == self.value:
if self.config_value == self.value:
return
log.info(f"Saving '{self.env_name}' to config.json")
path_parts = self.config_path.split(".")
config = CONFIG_DATA
for key in path_parts[:-1]:
if key not in config:
config[key] = {}
config = config[key]
config[path_parts[-1]] = self.value
save_config()
self.config_value = self.value
class AppConfig:
_state: dict[str, PersistentConfig]
def __init__(self):
super().__setattr__("_state", {})
def __setattr__(self, key, value):
if isinstance(value, PersistentConfig):
self._state[key] = value
else:
self._state[key].value = value
self._state[key].save()
def __getattr__(self, key):
return self._state[key].value
####################################
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
#################################### ####################################
# Static DIR # Static DIR
#################################### ####################################
@ -318,7 +409,9 @@ OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL
OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")] OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
OLLAMA_BASE_URLS = PersistentConfig(
"OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
)
#################################### ####################################
# OPENAI_API # OPENAI_API
@ -335,7 +428,9 @@ OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY
OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")] OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")]
OPENAI_API_KEYS = PersistentConfig(
"OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS
)
OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "") OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = ( OPENAI_API_BASE_URLS = (
@ -346,37 +441,42 @@ OPENAI_API_BASE_URLS = [
url.strip() if url != "" else "https://api.openai.com/v1" url.strip() if url != "" else "https://api.openai.com/v1"
for url in OPENAI_API_BASE_URLS.split(";") for url in OPENAI_API_BASE_URLS.split(";")
] ]
OPENAI_API_BASE_URLS = PersistentConfig(
"OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
)
OPENAI_API_KEY = "" OPENAI_API_KEY = ""
try: try:
OPENAI_API_KEY = OPENAI_API_KEYS[ OPENAI_API_KEY = OPENAI_API_KEYS.value[
OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
] ]
except: except:
pass pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1" OPENAI_API_BASE_URL = "https://api.openai.com/v1"
#################################### ####################################
# WEBUI # WEBUI
#################################### ####################################
ENABLE_SIGNUP = ( ENABLE_SIGNUP = PersistentConfig(
"ENABLE_SIGNUP",
"ui.enable_signup",
(
False False
if WEBUI_AUTH == False if not WEBUI_AUTH
else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true" else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
),
)
DEFAULT_MODELS = PersistentConfig(
"DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
) )
DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None)
DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
DEFAULT_PROMPT_SUGGESTIONS = ( "DEFAULT_PROMPT_SUGGESTIONS",
CONFIG_DATA["ui"]["prompt_suggestions"] "ui.prompt_suggestions",
if "ui" in CONFIG_DATA [
and "prompt_suggestions" in CONFIG_DATA["ui"]
and type(CONFIG_DATA["ui"]["prompt_suggestions"]) is list
else [
{ {
"title": ["Help me study", "vocabulary for a college entrance exam"], "title": ["Help me study", "vocabulary for a college entrance exam"],
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.", "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
@ -404,23 +504,40 @@ DEFAULT_PROMPT_SUGGESTIONS = (
"title": ["Overcome procrastination", "give me tips"], "title": ["Overcome procrastination", "give me tips"],
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?", "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
}, },
] ],
) )
DEFAULT_USER_ROLE = PersistentConfig(
DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending") "DEFAULT_USER_ROLE",
"ui.default_user_role",
os.getenv("DEFAULT_USER_ROLE", "pending"),
)
USER_PERMISSIONS_CHAT_DELETION = ( USER_PERMISSIONS_CHAT_DELETION = (
os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true" os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
) )
USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}} USER_PERMISSIONS = PersistentConfig(
"USER_PERMISSIONS",
"ui.user_permissions",
{"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}},
)
ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true" ENABLE_MODEL_FILTER = PersistentConfig(
"ENABLE_MODEL_FILTER",
"model_filter.enable",
os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
)
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] MODEL_FILTER_LIST = PersistentConfig(
"MODEL_FILTER_LIST",
"model_filter.list",
[model.strip() for model in MODEL_FILTER_LIST.split(";")],
)
WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "") WEBHOOK_URL = PersistentConfig(
"WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
)
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true" ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
@ -458,26 +575,45 @@ else:
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5")) RAG_TOP_K = PersistentConfig(
RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")) "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5"))
)
ENABLE_RAG_HYBRID_SEARCH = ( RAG_RELEVANCE_THRESHOLD = PersistentConfig(
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true" "RAG_RELEVANCE_THRESHOLD",
"rag.relevance_threshold",
float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")),
) )
ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( "ENABLE_RAG_HYBRID_SEARCH",
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true" "rag.enable_hybrid_search",
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
) )
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true" "rag.enable_web_loader_ssl_verification",
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
RAG_EMBEDDING_MODEL = os.environ.get(
"RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
) )
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
RAG_EMBEDDING_ENGINE = PersistentConfig(
"RAG_EMBEDDING_ENGINE",
"rag.embedding_engine",
os.environ.get("RAG_EMBEDDING_ENGINE", ""),
)
PDF_EXTRACT_IMAGES = PersistentConfig(
"PDF_EXTRACT_IMAGES",
"rag.pdf_extract_images",
os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true",
)
RAG_EMBEDDING_MODEL = PersistentConfig(
"RAG_EMBEDDING_MODEL",
"rag.embedding_model",
os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}"),
RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true" os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
@ -487,9 +623,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
) )
RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "") RAG_RERANKING_MODEL = PersistentConfig(
if not RAG_RERANKING_MODEL == "": "RAG_RERANKING_MODEL",
log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"), "rag.reranking_model",
os.environ.get("RAG_RERANKING_MODEL", ""),
)
if RAG_RERANKING_MODEL.value != "":
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}"),
RAG_RERANKING_MODEL_AUTO_UPDATE = ( RAG_RERANKING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true" os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
@ -527,9 +667,14 @@ if USE_CUDA.lower() == "true":
else: else:
DEVICE_TYPE = "cpu" DEVICE_TYPE = "cpu"
CHUNK_SIZE = PersistentConfig(
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500")) "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500"))
CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100")) )
CHUNK_OVERLAP = PersistentConfig(
"CHUNK_OVERLAP",
"rag.chunk_overlap",
int(os.environ.get("CHUNK_OVERLAP", "100")),
)
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags. DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context> <context>
@ -545,16 +690,32 @@ And answer according to the language of the user's question.
Given the context information, answer the query. Given the context information, answer the query.
Query: [query]""" Query: [query]"""
RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE) RAG_TEMPLATE = PersistentConfig(
"RAG_TEMPLATE",
"rag.template",
os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE),
)
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) RAG_OPENAI_API_BASE_URL = PersistentConfig(
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY) "RAG_OPENAI_API_BASE_URL",
"rag.openai_api_base_url",
os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
RAG_OPENAI_API_KEY = PersistentConfig(
"RAG_OPENAI_API_KEY",
"rag.openai_api_key",
os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
)
ENABLE_RAG_LOCAL_WEB_FETCH = ( ENABLE_RAG_LOCAL_WEB_FETCH = (
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true" os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
) )
YOUTUBE_LOADER_LANGUAGE = os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",") YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
"YOUTUBE_LOADER_LANGUAGE",
"rag.youtube_loader_language",
os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)
#################################### ####################################
# Transcribe # Transcribe
@ -571,34 +732,78 @@ WHISPER_MODEL_AUTO_UPDATE = (
# Images # Images
#################################### ####################################
IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "") IMAGE_GENERATION_ENGINE = PersistentConfig(
"IMAGE_GENERATION_ENGINE",
ENABLE_IMAGE_GENERATION = ( "image_generation.engine",
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true" os.getenv("IMAGE_GENERATION_ENGINE", ""),
) )
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "") ENABLE_IMAGE_GENERATION = PersistentConfig(
"ENABLE_IMAGE_GENERATION",
IMAGES_OPENAI_API_BASE_URL = os.getenv( "image_generation.enable",
"IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
)
AUTOMATIC1111_BASE_URL = PersistentConfig(
"AUTOMATIC1111_BASE_URL",
"image_generation.automatic1111.base_url",
os.getenv("AUTOMATIC1111_BASE_URL", ""),
) )
IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY)
IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512") COMFYUI_BASE_URL = PersistentConfig(
"COMFYUI_BASE_URL",
"image_generation.comfyui.base_url",
os.getenv("COMFYUI_BASE_URL", ""),
)
IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50)) IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
"IMAGES_OPENAI_API_BASE_URL",
"image_generation.openai.api_base_url",
os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
IMAGES_OPENAI_API_KEY = PersistentConfig(
"IMAGES_OPENAI_API_KEY",
"image_generation.openai.api_key",
os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)
IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "") IMAGE_SIZE = PersistentConfig(
"IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
)
IMAGE_STEPS = PersistentConfig(
"IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
)
IMAGE_GENERATION_MODEL = PersistentConfig(
"IMAGE_GENERATION_MODEL",
"image_generation.model",
os.getenv("IMAGE_GENERATION_MODEL", ""),
)
#################################### ####################################
# Audio # Audio
#################################### ####################################
AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) AUDIO_OPENAI_API_BASE_URL = PersistentConfig(
AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY) "AUDIO_OPENAI_API_BASE_URL",
AUDIO_OPENAI_API_MODEL = os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1") "audio.openai.api_base_url",
AUDIO_OPENAI_API_VOICE = os.getenv("AUDIO_OPENAI_API_VOICE", "alloy") os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_OPENAI_API_KEY = PersistentConfig(
"AUDIO_OPENAI_API_KEY",
"audio.openai.api_key",
os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_OPENAI_API_MODEL = PersistentConfig(
"AUDIO_OPENAI_API_MODEL",
"audio.openai.api_model",
os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"),
)
AUDIO_OPENAI_API_VOICE = PersistentConfig(
"AUDIO_OPENAI_API_VOICE",
"audio.openai.api_voice",
os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),
)
#################################### ####################################
# LiteLLM # LiteLLM

View File

@ -59,6 +59,7 @@ from config import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
WEBHOOK_URL, WEBHOOK_URL,
ENABLE_ADMIN_EXPORT, ENABLE_ADMIN_EXPORT,
AppConfig,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -107,10 +108,11 @@ app = FastAPI(
docs_url="/docs" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan docs_url="/docs" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan
) )
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config = AppConfig()
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.WEBHOOK_URL = WEBHOOK_URL app.state.config.WEBHOOK_URL = WEBHOOK_URL
origins = ["*"] origins = ["*"]
@ -250,9 +252,9 @@ async def get_app_config():
"version": VERSION, "version": VERSION,
"auth": WEBUI_AUTH, "auth": WEBUI_AUTH,
"default_locale": default_locale, "default_locale": default_locale,
"images": images_app.state.ENABLED, "images": images_app.state.config.ENABLED,
"default_models": webui_app.state.DEFAULT_MODELS, "default_models": webui_app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"admin_export_enabled": ENABLE_ADMIN_EXPORT, "admin_export_enabled": ENABLE_ADMIN_EXPORT,
} }
@ -261,8 +263,8 @@ async def get_app_config():
@app.get("/api/config/model/filter") @app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)): async def get_model_filter_config(user=Depends(get_admin_user)):
return { return {
"enabled": app.state.ENABLE_MODEL_FILTER, "enabled": app.state.config.ENABLE_MODEL_FILTER,
"models": app.state.MODEL_FILTER_LIST, "models": app.state.config.MODEL_FILTER_LIST,
} }
@ -275,28 +277,28 @@ class ModelFilterConfigForm(BaseModel):
async def update_model_filter_config( async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user) form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
): ):
app.state.ENABLE_MODEL_FILTER = form_data.enabled app.state.config.ENABLE_MODEL_FILTER, form_data.enabled
app.state.MODEL_FILTER_LIST = form_data.models app.state.config.MODEL_FILTER_LIST, form_data.models
ollama_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
openai_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
litellm_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
return { return {
"enabled": app.state.ENABLE_MODEL_FILTER, "enabled": app.state.config.ENABLE_MODEL_FILTER,
"models": app.state.MODEL_FILTER_LIST, "models": app.state.config.MODEL_FILTER_LIST,
} }
@app.get("/api/webhook") @app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)): async def get_webhook_url(user=Depends(get_admin_user)):
return { return {
"url": app.state.WEBHOOK_URL, "url": app.state.config.WEBHOOK_URL,
} }
@ -306,12 +308,12 @@ class UrlForm(BaseModel):
@app.post("/api/webhook") @app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
app.state.WEBHOOK_URL = form_data.url app.state.config.WEBHOOK_URL = form_data.url
webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
return { return {
"url": app.state.WEBHOOK_URL, "url": app.state.config.WEBHOOK_URL,
} }