From 058eb765687a542e4ce542d5fc8aea921ef85035 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Fri, 10 May 2024 13:36:10 +0800 Subject: [PATCH 1/6] feat: save UI config changes to config.json --- backend/apps/audio/main.py | 31 ++- backend/apps/images/main.py | 105 ++++---- backend/apps/ollama/main.py | 56 +++-- backend/apps/openai/main.py | 47 ++-- backend/apps/rag/main.py | 237 ++++++++++-------- backend/apps/web/main.py | 8 +- backend/apps/web/routers/auths.py | 38 +-- backend/apps/web/routers/configs.py | 9 +- backend/apps/web/routers/users.py | 8 +- backend/config.py | 366 ++++++++++++++++++++-------- backend/main.py | 42 ++-- 11 files changed, 611 insertions(+), 336 deletions(-) diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 87732d7bc..c3dc6a2c4 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -45,6 +45,8 @@ from config import ( AUDIO_OPENAI_API_KEY, AUDIO_OPENAI_API_MODEL, AUDIO_OPENAI_API_VOICE, + config_get, + config_set, ) log = logging.getLogger(__name__) @@ -83,10 +85,10 @@ class OpenAIConfigUpdateForm(BaseModel): @app.get("/config") async def get_openai_config(user=Depends(get_admin_user)): return { - "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.OPENAI_API_KEY, - "OPENAI_API_MODEL": app.state.OPENAI_API_MODEL, - "OPENAI_API_VOICE": app.state.OPENAI_API_VOICE, + "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), + "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), + "OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL), + "OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE), } @@ -97,17 +99,22 @@ async def update_openai_config( if form_data.key == "": raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) - app.state.OPENAI_API_BASE_URL = form_data.url - app.state.OPENAI_API_KEY = form_data.key - app.state.OPENAI_API_MODEL = form_data.model - app.state.OPENAI_API_VOICE = form_data.speaker + config_set(app.state.OPENAI_API_BASE_URL, form_data.url) + config_set(app.state.OPENAI_API_KEY, form_data.key) + config_set(app.state.OPENAI_API_MODEL, form_data.model) + config_set(app.state.OPENAI_API_VOICE, form_data.speaker) + + app.state.OPENAI_API_BASE_URL.save() + app.state.OPENAI_API_KEY.save() + app.state.OPENAI_API_MODEL.save() + app.state.OPENAI_API_VOICE.save() return { "status": True, - "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.OPENAI_API_KEY, - "OPENAI_API_MODEL": app.state.OPENAI_API_MODEL, - "OPENAI_API_VOICE": app.state.OPENAI_API_VOICE, + "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), + "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), + "OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL), + "OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE), } diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index f45cf0d12..8ebfb0446 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -42,6 +42,8 @@ from config import ( IMAGE_GENERATION_MODEL, IMAGE_SIZE, IMAGE_STEPS, + config_get, + config_set, ) @@ -79,7 +81,10 @@ app.state.IMAGE_STEPS = IMAGE_STEPS @app.get("/config") async def get_config(request: Request, user=Depends(get_admin_user)): - return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} + return { + "engine": config_get(app.state.ENGINE), + "enabled": config_get(app.state.ENABLED), + } class ConfigUpdateForm(BaseModel): @@ -89,9 +94,12 @@ class ConfigUpdateForm(BaseModel): @app.post("/config/update") async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - app.state.ENGINE = form_data.engine - app.state.ENABLED = form_data.enabled - return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} + config_set(app.state.ENGINE, form_data.engine) + config_set(app.state.ENABLED, form_data.enabled) + return { + "engine": config_get(app.state.ENGINE), + "enabled": config_get(app.state.ENABLED), + } class EngineUrlUpdateForm(BaseModel): @@ -102,8 +110,8 @@ class EngineUrlUpdateForm(BaseModel): @app.get("/url") async def get_engine_url(user=Depends(get_admin_user)): return { - "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, - "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, + "AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL), + "COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL), } @@ -113,29 +121,29 @@ async def update_engine_url( ): if form_data.AUTOMATIC1111_BASE_URL == None: - app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL + config_set(app.state.AUTOMATIC1111_BASE_URL, config_get(AUTOMATIC1111_BASE_URL)) else: url = form_data.AUTOMATIC1111_BASE_URL.strip("/") try: r = requests.head(url) - app.state.AUTOMATIC1111_BASE_URL = url + config_set(app.state.AUTOMATIC1111_BASE_URL, url) except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) if form_data.COMFYUI_BASE_URL == None: - app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL + config_set(app.state.COMFYUI_BASE_URL, COMFYUI_BASE_URL) else: url = form_data.COMFYUI_BASE_URL.strip("/") try: r = requests.head(url) - app.state.COMFYUI_BASE_URL = url + config_set(app.state.COMFYUI_BASE_URL, url) except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) return { - "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, - "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, + "AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL), + "COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL), "status": True, } @@ -148,8 +156,8 @@ class OpenAIConfigUpdateForm(BaseModel): @app.get("/openai/config") async def get_openai_config(user=Depends(get_admin_user)): return { - "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), + "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), } @@ -160,13 +168,13 @@ async def update_openai_config( if form_data.key == "": raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) - app.state.OPENAI_API_BASE_URL = form_data.url - app.state.OPENAI_API_KEY = form_data.key + config_set(app.state.OPENAI_API_BASE_URL, form_data.url) + config_set(app.state.OPENAI_API_KEY, form_data.key) return { "status": True, - "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), + "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), } @@ -176,7 +184,7 @@ class ImageSizeUpdateForm(BaseModel): @app.get("/size") async def get_image_size(user=Depends(get_admin_user)): - return {"IMAGE_SIZE": app.state.IMAGE_SIZE} + return {"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE)} @app.post("/size/update") @@ -185,9 +193,9 @@ async def update_image_size( ): pattern = r"^\d+x\d+$" # Regular expression pattern if re.match(pattern, form_data.size): - app.state.IMAGE_SIZE = form_data.size + config_set(app.state.IMAGE_SIZE, form_data.size) return { - "IMAGE_SIZE": app.state.IMAGE_SIZE, + "IMAGE_SIZE": config_get(app.state.IMAGE_SIZE), "status": True, } else: @@ -203,7 +211,7 @@ class ImageStepsUpdateForm(BaseModel): @app.get("/steps") async def get_image_size(user=Depends(get_admin_user)): - return {"IMAGE_STEPS": app.state.IMAGE_STEPS} + return {"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS)} @app.post("/steps/update") @@ -211,9 +219,9 @@ async def update_image_size( form_data: ImageStepsUpdateForm, user=Depends(get_admin_user) ): if form_data.steps >= 0: - app.state.IMAGE_STEPS = form_data.steps + config_set(app.state.IMAGE_STEPS, form_data.steps) return { - "IMAGE_STEPS": app.state.IMAGE_STEPS, + "IMAGE_STEPS": config_get(app.state.IMAGE_STEPS), "status": True, } else: @@ -263,15 +271,25 @@ def get_models(user=Depends(get_current_user)): async def get_default_model(user=Depends(get_admin_user)): try: if app.state.ENGINE == "openai": - return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} + return { + "model": ( + config_get(app.state.MODEL) + if config_get(app.state.MODEL) + else "dall-e-2" + ) + } elif app.state.ENGINE == "comfyui": - return {"model": app.state.MODEL if app.state.MODEL else ""} + return { + "model": ( + config_get(app.state.MODEL) if config_get(app.state.MODEL) else "" + ) + } else: r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") options = r.json() return {"model": options["sd_model_checkpoint"]} except Exception as e: - app.state.ENABLED = False + config_set(app.state.ENABLED, False) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -280,12 +298,9 @@ class UpdateModelForm(BaseModel): def set_model_handler(model: str): - if app.state.ENGINE == "openai": - app.state.MODEL = model - return app.state.MODEL - if app.state.ENGINE == "comfyui": - app.state.MODEL = model - return app.state.MODEL + if app.state.ENGINE in ["openai", "comfyui"]: + config_set(app.state.MODEL, model) + return config_get(app.state.MODEL) else: r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") options = r.json() @@ -382,7 +397,7 @@ def generate_image( user=Depends(get_current_user), ): - width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) + width, height = tuple(map(int, config_get(app.state.IMAGE_SIZE).split("x"))) r = None try: @@ -396,7 +411,11 @@ def generate_image( "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", "prompt": form_data.prompt, "n": form_data.n, - "size": form_data.size if form_data.size else app.state.IMAGE_SIZE, + "size": ( + form_data.size + if form_data.size + else config_get(app.state.IMAGE_SIZE) + ), "response_format": "b64_json", } @@ -430,19 +449,19 @@ def generate_image( "n": form_data.n, } - if app.state.IMAGE_STEPS != None: - data["steps"] = app.state.IMAGE_STEPS + if config_get(app.state.IMAGE_STEPS) is not None: + data["steps"] = config_get(app.state.IMAGE_STEPS) - if form_data.negative_prompt != None: + if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt data = ImageGenerationPayload(**data) res = comfyui_generate_image( - app.state.MODEL, + config_get(app.state.MODEL), data, user.id, - app.state.COMFYUI_BASE_URL, + config_get(app.state.COMFYUI_BASE_URL), ) log.debug(f"res: {res}") @@ -469,10 +488,10 @@ def generate_image( "height": height, } - if app.state.IMAGE_STEPS != None: - data["steps"] = app.state.IMAGE_STEPS + if config_get(app.state.IMAGE_STEPS) is not None: + data["steps"] = config_get(app.state.IMAGE_STEPS) - if form_data.negative_prompt != None: + if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt r = requests.post( diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 042d0336d..7dfadbb0c 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -46,6 +46,8 @@ from config import ( ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, UPLOAD_DIR, + config_set, + config_get, ) from utils.misc import calculate_sha256 @@ -96,7 +98,7 @@ async def get_status(): @app.get("/urls") async def get_ollama_api_urls(user=Depends(get_admin_user)): - return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} + return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)} class UrlUpdateForm(BaseModel): @@ -105,10 +107,10 @@ class UrlUpdateForm(BaseModel): @app.post("/urls/update") async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - app.state.OLLAMA_BASE_URLS = form_data.urls + config_set(app.state.OLLAMA_BASE_URLS, form_data.urls) log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}") - return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} + return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)} @app.get("/cancel/{request_id}") @@ -153,7 +155,9 @@ def merge_models_lists(model_lists): async def get_all_models(): log.info("get_all_models()") - tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] + tasks = [ + fetch_url(f"{url}/api/tags") for url in config_get(app.state.OLLAMA_BASE_URLS) + ] responses = await asyncio.gather(*tasks) models = { @@ -179,14 +183,15 @@ async def get_ollama_tags( if user.role == "user": models["models"] = list( filter( - lambda model: model["name"] in app.state.MODEL_FILTER_LIST, + lambda model: model["name"] + in config_get(app.state.MODEL_FILTER_LIST), models["models"], ) ) return models return models else: - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/tags") r.raise_for_status() @@ -216,7 +221,10 @@ async def get_ollama_versions(url_idx: Optional[int] = None): if url_idx == None: # returns lowest version - tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS] + tasks = [ + fetch_url(f"{url}/api/version") + for url in config_get(app.state.OLLAMA_BASE_URLS) + ] responses = await asyncio.gather(*tasks) responses = list(filter(lambda x: x is not None, responses)) @@ -235,7 +243,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, ) else: - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/version") r.raise_for_status() @@ -267,7 +275,7 @@ class ModelNameForm(BaseModel): async def pull_model( form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) ): - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") r = None @@ -355,7 +363,7 @@ async def push_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.debug(f"url: {url}") r = None @@ -417,7 +425,7 @@ async def create_model( form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) ): log.debug(f"form_data: {form_data}") - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") r = None @@ -490,7 +498,7 @@ async def copy_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") try: @@ -537,7 +545,7 @@ async def delete_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") try: @@ -577,7 +585,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us ) url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") try: @@ -634,7 +642,7 @@ async def generate_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") try: @@ -684,7 +692,7 @@ def generate_ollama_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") try: @@ -753,7 +761,7 @@ async def generate_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") r = None @@ -856,7 +864,7 @@ async def generate_chat_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") r = None @@ -965,7 +973,7 @@ async def generate_openai_chat_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") r = None @@ -1064,7 +1072,7 @@ async def get_openai_models( } else: - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/tags") r.raise_for_status() @@ -1198,7 +1206,7 @@ async def download_model( if url_idx == None: url_idx = 0 - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] file_name = parse_huggingface_url(form_data.url) @@ -1217,7 +1225,7 @@ async def download_model( def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): if url_idx == None: url_idx = 0 - ollama_url = app.state.OLLAMA_BASE_URLS[url_idx] + ollama_url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] file_path = f"{UPLOAD_DIR}/{file.filename}" @@ -1282,7 +1290,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): # async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None): # if url_idx == None: # url_idx = 0 -# url = app.state.OLLAMA_BASE_URLS[url_idx] +# url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] # file_location = os.path.join(UPLOAD_DIR, file.filename) # total_size = file.size @@ -1319,7 +1327,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): async def deprecated_proxy( path: str, request: Request, user=Depends(get_verified_user) ): - url = app.state.OLLAMA_BASE_URLS[0] + url = config_get(app.state.OLLAMA_BASE_URLS)[0] target_url = f"{url}/{path}" body = await request.body() diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index b5d1e68d6..36fed104c 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -26,6 +26,8 @@ from config import ( CACHE_DIR, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, + config_set, + config_get, ) from typing import List, Optional @@ -75,32 +77,34 @@ class KeysUpdateForm(BaseModel): @app.get("/urls") async def get_openai_urls(user=Depends(get_admin_user)): - return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} + return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)} @app.post("/urls/update") async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): await get_all_models() - app.state.OPENAI_API_BASE_URLS = form_data.urls - return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} + config_set(app.state.OPENAI_API_BASE_URLS, form_data.urls) + return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)} @app.get("/keys") async def get_openai_keys(user=Depends(get_admin_user)): - return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} + return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)} @app.post("/keys/update") async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): - app.state.OPENAI_API_KEYS = form_data.keys - return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} + config_set(app.state.OPENAI_API_KEYS, form_data.keys) + return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)} @app.post("/audio/speech") async def speech(request: Request, user=Depends(get_verified_user)): idx = None try: - idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") + idx = config_get(app.state.OPENAI_API_BASE_URLS).index( + "https://api.openai.com/v1" + ) body = await request.body() name = hashlib.sha256(body).hexdigest() @@ -114,13 +118,15 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) headers = {} - headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}" + headers["Authorization"] = ( + f"Bearer {config_get(app.state.OPENAI_API_KEYS)[idx]}" + ) headers["Content-Type"] = "application/json" r = None try: r = requests.post( - url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech", + url=f"{config_get(app.state.OPENAI_API_BASE_URLS)[idx]}/audio/speech", data=body, headers=headers, stream=True, @@ -180,7 +186,8 @@ def merge_models_lists(model_lists): [ {**model, "urlIdx": idx} for model in models - if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx] + if "api.openai.com" + not in config_get(app.state.OPENAI_API_BASE_URLS)[idx] or "gpt" in model["id"] ] ) @@ -191,12 +198,15 @@ def merge_models_lists(model_lists): async def get_all_models(): log.info("get_all_models()") - if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "": + if ( + len(config_get(app.state.OPENAI_API_KEYS)) == 1 + and config_get(app.state.OPENAI_API_KEYS)[0] == "" + ): models = {"data": []} else: tasks = [ - fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx]) - for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS) + fetch_url(f"{url}/models", config_get(app.state.OPENAI_API_KEYS)[idx]) + for idx, url in enumerate(config_get(app.state.OPENAI_API_BASE_URLS)) ] responses = await asyncio.gather(*tasks) @@ -228,18 +238,19 @@ async def get_all_models(): async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): if url_idx == None: models = await get_all_models() - if app.state.ENABLE_MODEL_FILTER: + if config_get(app.state.ENABLE_MODEL_FILTER): if user.role == "user": models["data"] = list( filter( - lambda model: model["id"] in app.state.MODEL_FILTER_LIST, + lambda model: model["id"] + in config_get(app.state.MODEL_FILTER_LIST), models["data"], ) ) return models return models else: - url = app.state.OPENAI_API_BASE_URLS[url_idx] + url = config_get(app.state.OPENAI_API_BASE_URLS)[url_idx] r = None @@ -303,8 +314,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): except json.JSONDecodeError as e: log.error("Error loading request body into a dictionary:", e) - url = app.state.OPENAI_API_BASE_URLS[idx] - key = app.state.OPENAI_API_KEYS[idx] + url = config_get(app.state.OPENAI_API_BASE_URLS)[idx] + key = config_get(app.state.OPENAI_API_KEYS)[idx] target_url = f"{url}/{path}" diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 2e2a8e209..f05447a66 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -93,6 +93,8 @@ from config import ( RAG_TEMPLATE, ENABLE_RAG_LOCAL_WEB_FETCH, YOUTUBE_LOADER_LANGUAGE, + config_set, + config_get, ) from constants import ERROR_MESSAGES @@ -133,7 +135,7 @@ def update_embedding_model( embedding_model: str, update_model: bool = False, ): - if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "": + if embedding_model and config_get(app.state.RAG_EMBEDDING_ENGINE) == "": app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( get_model_path(embedding_model, update_model), device=DEVICE_TYPE, @@ -158,22 +160,22 @@ def update_reranking_model( update_embedding_model( - app.state.RAG_EMBEDDING_MODEL, + config_get(app.state.RAG_EMBEDDING_MODEL), RAG_EMBEDDING_MODEL_AUTO_UPDATE, ) update_reranking_model( - app.state.RAG_RERANKING_MODEL, + config_get(app.state.RAG_RERANKING_MODEL), RAG_RERANKING_MODEL_AUTO_UPDATE, ) app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.RAG_EMBEDDING_ENGINE, - app.state.RAG_EMBEDDING_MODEL, + config_get(app.state.RAG_EMBEDDING_ENGINE), + config_get(app.state.RAG_EMBEDDING_MODEL), app.state.sentence_transformer_ef, - app.state.OPENAI_API_KEY, - app.state.OPENAI_API_BASE_URL, + config_get(app.state.OPENAI_API_KEY), + config_get(app.state.OPENAI_API_BASE_URL), ) origins = ["*"] @@ -200,12 +202,12 @@ class UrlForm(CollectionNameForm): async def get_status(): return { "status": True, - "chunk_size": app.state.CHUNK_SIZE, - "chunk_overlap": app.state.CHUNK_OVERLAP, - "template": app.state.RAG_TEMPLATE, - "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.RAG_EMBEDDING_MODEL, - "reranking_model": app.state.RAG_RERANKING_MODEL, + "chunk_size": config_get(app.state.CHUNK_SIZE), + "chunk_overlap": config_get(app.state.CHUNK_OVERLAP), + "template": config_get(app.state.RAG_TEMPLATE), + "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE), + "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL), + "reranking_model": config_get(app.state.RAG_RERANKING_MODEL), } @@ -213,18 +215,21 @@ async def get_status(): async def get_embedding_config(user=Depends(get_admin_user)): return { "status": True, - "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.RAG_EMBEDDING_MODEL, + "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE), + "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL), "openai_config": { - "url": app.state.OPENAI_API_BASE_URL, - "key": app.state.OPENAI_API_KEY, + "url": config_get(app.state.OPENAI_API_BASE_URL), + "key": config_get(app.state.OPENAI_API_KEY), }, } @app.get("/reranking") async def get_reraanking_config(user=Depends(get_admin_user)): - return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL} + return { + "status": True, + "reranking_model": config_get(app.state.RAG_RERANKING_MODEL), + } class OpenAIConfigForm(BaseModel): @@ -246,31 +251,31 @@ async def update_embedding_config( f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) try: - app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine - app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model + config_set(app.state.RAG_EMBEDDING_ENGINE, form_data.embedding_engine) + config_set(app.state.RAG_EMBEDDING_MODEL, form_data.embedding_model) - if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: + if config_get(app.state.RAG_EMBEDDING_ENGINE) in ["ollama", "openai"]: if form_data.openai_config != None: - app.state.OPENAI_API_BASE_URL = form_data.openai_config.url - app.state.OPENAI_API_KEY = form_data.openai_config.key + config_set(app.state.OPENAI_API_BASE_URL, form_data.openai_config.url) + config_set(app.state.OPENAI_API_KEY, form_data.openai_config.key) - update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True) + update_embedding_model(config_get(app.state.RAG_EMBEDDING_MODEL), True) app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.RAG_EMBEDDING_ENGINE, - app.state.RAG_EMBEDDING_MODEL, + config_get(app.state.RAG_EMBEDDING_ENGINE), + config_get(app.state.RAG_EMBEDDING_MODEL), app.state.sentence_transformer_ef, - app.state.OPENAI_API_KEY, - app.state.OPENAI_API_BASE_URL, + config_get(app.state.OPENAI_API_KEY), + config_get(app.state.OPENAI_API_BASE_URL), ) return { "status": True, - "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.RAG_EMBEDDING_MODEL, + "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE), + "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL), "openai_config": { - "url": app.state.OPENAI_API_BASE_URL, - "key": app.state.OPENAI_API_KEY, + "url": config_get(app.state.OPENAI_API_BASE_URL), + "key": config_get(app.state.OPENAI_API_KEY), }, } except Exception as e: @@ -293,13 +298,13 @@ async def update_reranking_config( f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}" ) try: - app.state.RAG_RERANKING_MODEL = form_data.reranking_model + config_set(app.state.RAG_RERANKING_MODEL, form_data.reranking_model) - update_reranking_model(app.state.RAG_RERANKING_MODEL, True) + update_reranking_model(config_get(app.state.RAG_RERANKING_MODEL), True) return { "status": True, - "reranking_model": app.state.RAG_RERANKING_MODEL, + "reranking_model": config_get(app.state.RAG_RERANKING_MODEL), } except Exception as e: log.exception(f"Problem updating reranking model: {e}") @@ -313,14 +318,16 @@ async def update_reranking_config( async def get_rag_config(user=Depends(get_admin_user)): return { "status": True, - "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, + "pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES), "chunk": { - "chunk_size": app.state.CHUNK_SIZE, - "chunk_overlap": app.state.CHUNK_OVERLAP, + "chunk_size": config_get(app.state.CHUNK_SIZE), + "chunk_overlap": config_get(app.state.CHUNK_OVERLAP), }, - "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": config_get( + app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION + ), "youtube": { - "language": app.state.YOUTUBE_LOADER_LANGUAGE, + "language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE), "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, } @@ -345,50 +352,69 @@ class ConfigUpdateForm(BaseModel): @app.post("/config/update") async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - app.state.PDF_EXTRACT_IMAGES = ( - form_data.pdf_extract_images - if form_data.pdf_extract_images != None - else app.state.PDF_EXTRACT_IMAGES + config_set( + app.state.PDF_EXTRACT_IMAGES, + ( + form_data.pdf_extract_images + if form_data.pdf_extract_images is not None + else config_get(app.state.PDF_EXTRACT_IMAGES) + ), ) - app.state.CHUNK_SIZE = ( - form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE + config_set( + app.state.CHUNK_SIZE, + ( + form_data.chunk.chunk_size + if form_data.chunk is not None + else config_get(app.state.CHUNK_SIZE) + ), ) - app.state.CHUNK_OVERLAP = ( - form_data.chunk.chunk_overlap - if form_data.chunk != None - else app.state.CHUNK_OVERLAP + config_set( + app.state.CHUNK_OVERLAP, + ( + form_data.chunk.chunk_overlap + if form_data.chunk is not None + else config_get(app.state.CHUNK_OVERLAP) + ), ) - app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( - form_data.web_loader_ssl_verification - if form_data.web_loader_ssl_verification != None - else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION + config_set( + app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + ( + form_data.web_loader_ssl_verification + if form_data.web_loader_ssl_verification != None + else config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION) + ), ) - app.state.YOUTUBE_LOADER_LANGUAGE = ( - form_data.youtube.language - if form_data.youtube != None - else app.state.YOUTUBE_LOADER_LANGUAGE + config_set( + app.state.YOUTUBE_LOADER_LANGUAGE, + ( + form_data.youtube.language + if form_data.youtube is not None + else config_get(app.state.YOUTUBE_LOADER_LANGUAGE) + ), ) app.state.YOUTUBE_LOADER_TRANSLATION = ( form_data.youtube.translation - if form_data.youtube != None + if form_data.youtube is not None else app.state.YOUTUBE_LOADER_TRANSLATION ) return { "status": True, - "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, + "pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES), "chunk": { - "chunk_size": app.state.CHUNK_SIZE, - "chunk_overlap": app.state.CHUNK_OVERLAP, + "chunk_size": config_get(app.state.CHUNK_SIZE), + "chunk_overlap": config_get(app.state.CHUNK_OVERLAP), }, - "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": config_get( + app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION + ), "youtube": { - "language": app.state.YOUTUBE_LOADER_LANGUAGE, + "language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE), "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, } @@ -398,7 +424,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ async def get_rag_template(user=Depends(get_current_user)): return { "status": True, - "template": app.state.RAG_TEMPLATE, + "template": config_get(app.state.RAG_TEMPLATE), } @@ -406,10 +432,10 @@ async def get_rag_template(user=Depends(get_current_user)): async def get_query_settings(user=Depends(get_admin_user)): return { "status": True, - "template": app.state.RAG_TEMPLATE, - "k": app.state.TOP_K, - "r": app.state.RELEVANCE_THRESHOLD, - "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, + "template": config_get(app.state.RAG_TEMPLATE), + "k": config_get(app.state.TOP_K), + "r": config_get(app.state.RELEVANCE_THRESHOLD), + "hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH), } @@ -424,16 +450,22 @@ class QuerySettingsForm(BaseModel): async def update_query_settings( form_data: QuerySettingsForm, user=Depends(get_admin_user) ): - app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE - app.state.TOP_K = form_data.k if form_data.k else 4 - app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 - app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False + config_set( + app.state.RAG_TEMPLATE, + form_data.template if form_data.template else RAG_TEMPLATE, + ) + config_set(app.state.TOP_K, form_data.k if form_data.k else 4) + config_set(app.state.RELEVANCE_THRESHOLD, form_data.r if form_data.r else 0.0) + config_set( + app.state.ENABLE_RAG_HYBRID_SEARCH, + form_data.hybrid if form_data.hybrid else False, + ) return { "status": True, - "template": app.state.RAG_TEMPLATE, - "k": app.state.TOP_K, - "r": app.state.RELEVANCE_THRESHOLD, - "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, + "template": config_get(app.state.RAG_TEMPLATE), + "k": config_get(app.state.TOP_K), + "r": config_get(app.state.RELEVANCE_THRESHOLD), + "hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH), } @@ -451,21 +483,25 @@ def query_doc_handler( user=Depends(get_current_user), ): try: - if app.state.ENABLE_RAG_HYBRID_SEARCH: + if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH): return query_doc_with_hybrid_search( collection_name=form_data.collection_name, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.TOP_K, + k=form_data.k if form_data.k else config_get(app.state.TOP_K), reranking_function=app.state.sentence_transformer_rf, - r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, + r=( + form_data.r + if form_data.r + else config_get(app.state.RELEVANCE_THRESHOLD) + ), ) else: return query_doc( collection_name=form_data.collection_name, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.TOP_K, + k=form_data.k if form_data.k else config_get(app.state.TOP_K), ) except Exception as e: log.exception(e) @@ -489,21 +525,25 @@ def query_collection_handler( user=Depends(get_current_user), ): try: - if app.state.ENABLE_RAG_HYBRID_SEARCH: + if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH): return query_collection_with_hybrid_search( collection_names=form_data.collection_names, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.TOP_K, + k=form_data.k if form_data.k else config_get(app.state.TOP_K), reranking_function=app.state.sentence_transformer_rf, - r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, + r=( + form_data.r + if form_data.r + else config_get(app.state.RELEVANCE_THRESHOLD) + ), ) else: return query_collection( collection_names=form_data.collection_names, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.TOP_K, + k=form_data.k if form_data.k else config_get(app.state.TOP_K), ) except Exception as e: @@ -520,8 +560,8 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)): loader = YoutubeLoader.from_youtube_url( form_data.url, add_video_info=True, - language=app.state.YOUTUBE_LOADER_LANGUAGE, - translation=app.state.YOUTUBE_LOADER_TRANSLATION, + language=config_get(app.state.YOUTUBE_LOADER_LANGUAGE), + translation=config_get(app.state.YOUTUBE_LOADER_TRANSLATION), ) data = loader.load() @@ -548,7 +588,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)): # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" try: loader = get_web_loader( - form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION + form_data.url, + verify_ssl=config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION), ) data = loader.load() @@ -604,8 +645,8 @@ def resolve_hostname(hostname): def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.CHUNK_SIZE, - chunk_overlap=app.state.CHUNK_OVERLAP, + chunk_size=config_get(app.state.CHUNK_SIZE), + chunk_overlap=config_get(app.state.CHUNK_OVERLAP), add_start_index=True, ) @@ -622,8 +663,8 @@ def store_text_in_vector_db( text, metadata, collection_name, overwrite: bool = False ) -> bool: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.CHUNK_SIZE, - chunk_overlap=app.state.CHUNK_OVERLAP, + chunk_size=config_get(app.state.CHUNK_SIZE), + chunk_overlap=config_get(app.state.CHUNK_OVERLAP), add_start_index=True, ) docs = text_splitter.create_documents([text], metadatas=[metadata]) @@ -646,11 +687,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b collection = CHROMA_CLIENT.create_collection(name=collection_name) embedding_func = get_embedding_function( - app.state.RAG_EMBEDDING_ENGINE, - app.state.RAG_EMBEDDING_MODEL, + config_get(app.state.RAG_EMBEDDING_ENGINE), + config_get(app.state.RAG_EMBEDDING_MODEL), app.state.sentence_transformer_ef, - app.state.OPENAI_API_KEY, - app.state.OPENAI_API_BASE_URL, + config_get(app.state.OPENAI_API_KEY), + config_get(app.state.OPENAI_API_BASE_URL), ) embedding_texts = list(map(lambda x: x.replace("\n", " "), texts)) @@ -724,7 +765,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str): ] if file_ext == "pdf": - loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES) + loader = PyPDFLoader( + file_path, extract_images=config_get(app.state.PDF_EXTRACT_IMAGES) + ) elif file_ext == "csv": loader = CSVLoader(file_path) elif file_ext == "rst": diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index 66cdfb3d4..2bed33543 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -21,6 +21,8 @@ from config import ( USER_PERMISSIONS, WEBHOOK_URL, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + JWT_EXPIRES_IN, + config_get, ) app = FastAPI() @@ -28,7 +30,7 @@ app = FastAPI() origins = ["*"] app.state.ENABLE_SIGNUP = ENABLE_SIGNUP -app.state.JWT_EXPIRES_IN = "-1" +app.state.JWT_EXPIRES_IN = JWT_EXPIRES_IN app.state.DEFAULT_MODELS = DEFAULT_MODELS app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS @@ -61,6 +63,6 @@ async def get_status(): return { "status": True, "auth": WEBUI_AUTH, - "default_models": app.state.DEFAULT_MODELS, - "default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS, + "default_models": config_get(app.state.DEFAULT_MODELS), + "default_prompt_suggestions": config_get(app.state.DEFAULT_PROMPT_SUGGESTIONS), } diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index 9fa962dda..0bc4967f9 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -33,7 +33,7 @@ from utils.utils import ( from utils.misc import parse_duration, validate_email_format from utils.webhook import post_webhook from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES -from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER +from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, config_get, config_set router = APIRouter() @@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm): if user: token = create_token( data={"id": user.id}, - expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), + expires_delta=parse_duration(config_get(request.app.state.JWT_EXPIRES_IN)), ) return { @@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm): @router.post("/signup", response_model=SigninResponse) async def signup(request: Request, form_data: SignupForm): - if not request.app.state.ENABLE_SIGNUP and WEBUI_AUTH: + if not config_get(request.app.state.ENABLE_SIGNUP) and WEBUI_AUTH: raise HTTPException( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED ) @@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm): role = ( "admin" if Users.get_num_users() == 0 - else request.app.state.DEFAULT_USER_ROLE + else config_get(request.app.state.DEFAULT_USER_ROLE) ) hashed = get_password_hash(form_data.password) user = Auths.insert_new_auth( @@ -194,13 +194,15 @@ async def signup(request: Request, form_data: SignupForm): if user: token = create_token( data={"id": user.id}, - expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), + expires_delta=parse_duration( + config_get(request.app.state.JWT_EXPIRES_IN) + ), ) # response.set_cookie(key='token', value=token, httponly=True) - if request.app.state.WEBHOOK_URL: + if config_get(request.app.state.WEBHOOK_URL): post_webhook( - request.app.state.WEBHOOK_URL, + config_get(request.app.state.WEBHOOK_URL), WEBHOOK_MESSAGES.USER_SIGNUP(user.name), { "action": "signup", @@ -276,13 +278,15 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): @router.get("/signup/enabled", response_model=bool) async def get_sign_up_status(request: Request, user=Depends(get_admin_user)): - return request.app.state.ENABLE_SIGNUP + return config_get(request.app.state.ENABLE_SIGNUP) @router.get("/signup/enabled/toggle", response_model=bool) async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): - request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP - return request.app.state.ENABLE_SIGNUP + config_set( + request.app.state.ENABLE_SIGNUP, not config_get(request.app.state.ENABLE_SIGNUP) + ) + return config_get(request.app.state.ENABLE_SIGNUP) ############################ @@ -292,7 +296,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): @router.get("/signup/user/role") async def get_default_user_role(request: Request, user=Depends(get_admin_user)): - return request.app.state.DEFAULT_USER_ROLE + return config_get(request.app.state.DEFAULT_USER_ROLE) class UpdateRoleForm(BaseModel): @@ -304,8 +308,8 @@ async def update_default_user_role( request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user) ): if form_data.role in ["pending", "user", "admin"]: - request.app.state.DEFAULT_USER_ROLE = form_data.role - return request.app.state.DEFAULT_USER_ROLE + config_set(request.app.state.DEFAULT_USER_ROLE, form_data.role) + return config_get(request.app.state.DEFAULT_USER_ROLE) ############################ @@ -315,7 +319,7 @@ async def update_default_user_role( @router.get("/token/expires") async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)): - return request.app.state.JWT_EXPIRES_IN + return config_get(request.app.state.JWT_EXPIRES_IN) class UpdateJWTExpiresDurationForm(BaseModel): @@ -332,10 +336,10 @@ async def update_token_expires_duration( # Check if the input string matches the pattern if re.match(pattern, form_data.duration): - request.app.state.JWT_EXPIRES_IN = form_data.duration - return request.app.state.JWT_EXPIRES_IN + config_set(request.app.state.JWT_EXPIRES_IN, form_data.duration) + return config_get(request.app.state.JWT_EXPIRES_IN) else: - return request.app.state.JWT_EXPIRES_IN + return config_get(request.app.state.JWT_EXPIRES_IN) ############################ diff --git a/backend/apps/web/routers/configs.py b/backend/apps/web/routers/configs.py index 0bad55a6a..d726cd2dc 100644 --- a/backend/apps/web/routers/configs.py +++ b/backend/apps/web/routers/configs.py @@ -9,6 +9,7 @@ import time import uuid from apps.web.models.users import Users +from config import config_set, config_get from utils.utils import ( get_password_hash, @@ -44,8 +45,8 @@ class SetDefaultSuggestionsForm(BaseModel): async def set_global_default_models( request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user) ): - request.app.state.DEFAULT_MODELS = form_data.models - return request.app.state.DEFAULT_MODELS + config_set(request.app.state.DEFAULT_MODELS, form_data.models) + return config_get(request.app.state.DEFAULT_MODELS) @router.post("/default/suggestions", response_model=List[PromptSuggestion]) @@ -55,5 +56,5 @@ async def set_global_default_suggestions( user=Depends(get_admin_user), ): data = form_data.model_dump() - request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] - return request.app.state.DEFAULT_PROMPT_SUGGESTIONS + config_set(request.app.state.DEFAULT_PROMPT_SUGGESTIONS, data["suggestions"]) + return config_get(request.app.state.DEFAULT_PROMPT_SUGGESTIONS) diff --git a/backend/apps/web/routers/users.py b/backend/apps/web/routers/users.py index 59f6c21b7..302432540 100644 --- a/backend/apps/web/routers/users.py +++ b/backend/apps/web/routers/users.py @@ -15,7 +15,7 @@ from apps.web.models.auths import Auths from utils.utils import get_current_user, get_password_hash, get_admin_user from constants import ERROR_MESSAGES -from config import SRC_LOG_LEVELS +from config import SRC_LOG_LEVELS, config_set, config_get log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user) @router.get("/permissions/user") async def get_user_permissions(request: Request, user=Depends(get_admin_user)): - return request.app.state.USER_PERMISSIONS + return config_get(request.app.state.USER_PERMISSIONS) @router.post("/permissions/user") async def update_user_permissions( request: Request, form_data: dict, user=Depends(get_admin_user) ): - request.app.state.USER_PERMISSIONS = form_data - return request.app.state.USER_PERMISSIONS + config_set(request.app.state.USER_PERMISSIONS, form_data) + return config_get(request.app.state.USER_PERMISSIONS) ############################ diff --git a/backend/config.py b/backend/config.py index 5c6247a9f..028e6caf0 100644 --- a/backend/config.py +++ b/backend/config.py @@ -5,6 +5,7 @@ import chromadb from chromadb import Settings from base64 import b64encode from bs4 import BeautifulSoup +from typing import TypeVar, Generic, Union from pathlib import Path import json @@ -17,7 +18,6 @@ import shutil from secrets import token_bytes from constants import ERROR_MESSAGES - #################################### # Load .env file #################################### @@ -29,7 +29,6 @@ try: except ImportError: print("dotenv not installed, skipping...") - #################################### # LOGGING #################################### @@ -71,7 +70,6 @@ for source in log_sources: log.setLevel(SRC_LOG_LEVELS["CONFIG"]) - WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") if WEBUI_NAME != "Open WebUI": WEBUI_NAME += " (Open WebUI)" @@ -80,7 +78,6 @@ WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000") WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" - #################################### # ENV (dev,test,prod) #################################### @@ -151,26 +148,14 @@ for version in soup.find_all("h2"): changelog_json[version_number] = version_data - CHANGELOG = changelog_json - #################################### # WEBUI_VERSION #################################### WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100") -#################################### -# WEBUI_AUTH (Required for security) -#################################### - -WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" -WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( - "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None -) - - #################################### # DATA/FRONTEND BUILD DIR #################################### @@ -184,6 +169,93 @@ try: except: CONFIG_DATA = {} + +#################################### +# Config helpers +#################################### + + +def save_config(): + try: + with open(f"{DATA_DIR}/config.json", "w") as f: + json.dump(CONFIG_DATA, f, indent="\t") + except Exception as e: + log.exception(e) + + +def get_config_value(config_path: str): + path_parts = config_path.split(".") + cur_config = CONFIG_DATA + for key in path_parts: + if key in cur_config: + cur_config = cur_config[key] + else: + return None + return cur_config + + +T = TypeVar("T") + + +class WrappedConfig(Generic[T]): + def __init__(self, env_name: str, config_path: str, env_value: T): + self.env_name = env_name + self.config_path = config_path + self.env_value = env_value + self.config_value = get_config_value(config_path) + if self.config_value is not None: + log.info(f"'{env_name}' loaded from config.json") + self.value = self.config_value + else: + self.value = env_value + + def __str__(self): + return str(self.value) + + def save(self): + # Don't save if the value is the same as the env value and the config value + if self.env_value == self.value: + if self.config_value == self.value: + return + log.info(f"Saving '{self.env_name}' to config.json") + path_parts = self.config_path.split(".") + config = CONFIG_DATA + for key in path_parts[:-1]: + if key not in config: + config[key] = {} + config = config[key] + config[path_parts[-1]] = self.value + save_config() + self.config_value = self.value + + +def config_set(config: Union[WrappedConfig[T], T], value: T, save_config=True): + if isinstance(config, WrappedConfig): + config.value = value + if save_config: + config.save() + else: + config = value + + +def config_get(config: Union[WrappedConfig[T], T]) -> T: + if isinstance(config, WrappedConfig): + return config.value + return config + + +#################################### +# WEBUI_AUTH (Required for security) +#################################### + +WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" +WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( + "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None +) +JWT_EXPIRES_IN = WrappedConfig( + "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") +) + #################################### # Static DIR #################################### @@ -225,7 +297,6 @@ if CUSTOM_NAME: log.exception(e) pass - #################################### # File Upload DIR #################################### @@ -233,7 +304,6 @@ if CUSTOM_NAME: UPLOAD_DIR = f"{DATA_DIR}/uploads" Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) - #################################### # Cache DIR #################################### @@ -241,7 +311,6 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) CACHE_DIR = f"{DATA_DIR}/cache" Path(CACHE_DIR).mkdir(parents=True, exist_ok=True) - #################################### # Docs DIR #################################### @@ -282,7 +351,6 @@ if not os.path.exists(LITELLM_CONFIG_PATH): create_config_file(LITELLM_CONFIG_PATH) log.info("Config file created successfully.") - #################################### # OLLAMA_BASE_URL #################################### @@ -313,12 +381,13 @@ if ENV == "prod": elif K8S_FLAG: OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434" - OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "") OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")] - +OLLAMA_BASE_URLS = WrappedConfig( + "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS +) #################################### # OPENAI_API @@ -327,7 +396,6 @@ OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")] OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") - if OPENAI_API_BASE_URL == "": OPENAI_API_BASE_URL = "https://api.openai.com/v1" @@ -335,7 +403,7 @@ OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "") OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")] - +OPENAI_API_KEYS = WrappedConfig("OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS) OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "") OPENAI_API_BASE_URLS = ( @@ -346,37 +414,42 @@ OPENAI_API_BASE_URLS = [ url.strip() if url != "" else "https://api.openai.com/v1" for url in OPENAI_API_BASE_URLS.split(";") ] +OPENAI_API_BASE_URLS = WrappedConfig( + "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS +) OPENAI_API_KEY = "" try: - OPENAI_API_KEY = OPENAI_API_KEYS[ - OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") + OPENAI_API_KEY = OPENAI_API_KEYS.value[ + OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1") ] except: pass OPENAI_API_BASE_URL = "https://api.openai.com/v1" - #################################### # WEBUI #################################### -ENABLE_SIGNUP = ( - False - if WEBUI_AUTH == False - else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true" +ENABLE_SIGNUP = WrappedConfig( + "ENABLE_SIGNUP", + "ui.enable_signup", + ( + False + if not WEBUI_AUTH + else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true" + ), +) +DEFAULT_MODELS = WrappedConfig( + "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None) ) -DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None) - -DEFAULT_PROMPT_SUGGESTIONS = ( - CONFIG_DATA["ui"]["prompt_suggestions"] - if "ui" in CONFIG_DATA - and "prompt_suggestions" in CONFIG_DATA["ui"] - and type(CONFIG_DATA["ui"]["prompt_suggestions"]) is list - else [ +DEFAULT_PROMPT_SUGGESTIONS = WrappedConfig( + "DEFAULT_PROMPT_SUGGESTIONS", + "ui.prompt_suggestions", + [ { "title": ["Help me study", "vocabulary for a college entrance exam"], "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.", @@ -404,23 +477,42 @@ DEFAULT_PROMPT_SUGGESTIONS = ( "title": ["Overcome procrastination", "give me tips"], "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?", }, - ] + ], ) - -DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending") - -USER_PERMISSIONS_CHAT_DELETION = ( - os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true" +DEFAULT_USER_ROLE = WrappedConfig( + "DEFAULT_USER_ROLE", + "ui.default_user_role", + os.getenv("DEFAULT_USER_ROLE", "pending"), ) -USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}} +USER_PERMISSIONS_CHAT_DELETION = WrappedConfig( + "USER_PERMISSIONS_CHAT_DELETION", + "ui.user_permissions.chat.deletion", + os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true", +) -ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true" +USER_PERMISSIONS = WrappedConfig( + "USER_PERMISSIONS", + "ui.user_permissions", + {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}, +) + +ENABLE_MODEL_FILTER = WrappedConfig( + "ENABLE_MODEL_FILTER", + "model_filter.enable", + os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true", +) MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") -MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] +MODEL_FILTER_LIST = WrappedConfig( + "MODEL_FILTER_LIST", + "model_filter.list", + [model.strip() for model in MODEL_FILTER_LIST.split(";")], +) -WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "") +WEBHOOK_URL = WrappedConfig( + "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "") +) ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true" @@ -458,26 +550,45 @@ else: CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) -RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5")) -RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")) - -ENABLE_RAG_HYBRID_SEARCH = ( - os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true" +RAG_TOP_K = WrappedConfig( + "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5")) +) +RAG_RELEVANCE_THRESHOLD = WrappedConfig( + "RAG_RELEVANCE_THRESHOLD", + "rag.relevance_threshold", + float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")), ) - -ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( - os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true" +ENABLE_RAG_HYBRID_SEARCH = WrappedConfig( + "ENABLE_RAG_HYBRID_SEARCH", + "rag.enable_hybrid_search", + os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true", ) -RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") - -PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true" - -RAG_EMBEDDING_MODEL = os.environ.get( - "RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2" +ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = WrappedConfig( + "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", + "rag.enable_web_loader_ssl_verification", + os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true", ) -log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"), + +RAG_EMBEDDING_ENGINE = WrappedConfig( + "RAG_EMBEDDING_ENGINE", + "rag.embedding_engine", + os.environ.get("RAG_EMBEDDING_ENGINE", ""), +) + +PDF_EXTRACT_IMAGES = WrappedConfig( + "PDF_EXTRACT_IMAGES", + "rag.pdf_extract_images", + os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true", +) + +RAG_EMBEDDING_MODEL = WrappedConfig( + "RAG_EMBEDDING_MODEL", + "rag.embedding_model", + os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"), +) +log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}"), RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true" @@ -487,9 +598,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) -RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "") -if not RAG_RERANKING_MODEL == "": - log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"), +RAG_RERANKING_MODEL = WrappedConfig( + "RAG_RERANKING_MODEL", + "rag.reranking_model", + os.environ.get("RAG_RERANKING_MODEL", ""), +) +if RAG_RERANKING_MODEL.value != "": + log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}"), RAG_RERANKING_MODEL_AUTO_UPDATE = ( os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true" @@ -499,7 +614,6 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) - if CHROMA_HTTP_HOST != "": CHROMA_CLIENT = chromadb.HttpClient( host=CHROMA_HTTP_HOST, @@ -518,7 +632,6 @@ else: database=CHROMA_DATABASE, ) - # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") @@ -527,9 +640,14 @@ if USE_CUDA.lower() == "true": else: DEVICE_TYPE = "cpu" - -CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500")) -CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100")) +CHUNK_SIZE = WrappedConfig( + "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500")) +) +CHUNK_OVERLAP = WrappedConfig( + "CHUNK_OVERLAP", + "rag.chunk_overlap", + int(os.environ.get("CHUNK_OVERLAP", "100")), +) DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside XML tags. @@ -545,16 +663,32 @@ And answer according to the language of the user's question. Given the context information, answer the query. Query: [query]""" -RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE) +RAG_TEMPLATE = WrappedConfig( + "RAG_TEMPLATE", + "rag.template", + os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE), +) -RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) -RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY) +RAG_OPENAI_API_BASE_URL = WrappedConfig( + "RAG_OPENAI_API_BASE_URL", + "rag.openai_api_base_url", + os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +) +RAG_OPENAI_API_KEY = WrappedConfig( + "RAG_OPENAI_API_KEY", + "rag.openai_api_key", + os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY), +) ENABLE_RAG_LOCAL_WEB_FETCH = ( os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true" ) -YOUTUBE_LOADER_LANGUAGE = os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",") +YOUTUBE_LOADER_LANGUAGE = WrappedConfig( + "YOUTUBE_LOADER_LANGUAGE", + "rag.youtube_loader_language", + os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","), +) #################################### # Transcribe @@ -566,39 +700,82 @@ WHISPER_MODEL_AUTO_UPDATE = ( os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" ) - #################################### # Images #################################### -IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "") - -ENABLE_IMAGE_GENERATION = ( - os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true" +IMAGE_GENERATION_ENGINE = WrappedConfig( + "IMAGE_GENERATION_ENGINE", + "image_generation.engine", + os.getenv("IMAGE_GENERATION_ENGINE", ""), ) -AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "") -COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "") - -IMAGES_OPENAI_API_BASE_URL = os.getenv( - "IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL +ENABLE_IMAGE_GENERATION = WrappedConfig( + "ENABLE_IMAGE_GENERATION", + "image_generation.enable", + os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true", +) +AUTOMATIC1111_BASE_URL = WrappedConfig( + "AUTOMATIC1111_BASE_URL", + "image_generation.automatic1111.base_url", + os.getenv("AUTOMATIC1111_BASE_URL", ""), ) -IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY) -IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512") +COMFYUI_BASE_URL = WrappedConfig( + "COMFYUI_BASE_URL", + "image_generation.comfyui.base_url", + os.getenv("COMFYUI_BASE_URL", ""), +) -IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50)) +IMAGES_OPENAI_API_BASE_URL = WrappedConfig( + "IMAGES_OPENAI_API_BASE_URL", + "image_generation.openai.api_base_url", + os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +) +IMAGES_OPENAI_API_KEY = WrappedConfig( + "IMAGES_OPENAI_API_KEY", + "image_generation.openai.api_key", + os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY), +) -IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "") +IMAGE_SIZE = WrappedConfig( + "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512") +) + +IMAGE_STEPS = WrappedConfig( + "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50)) +) + +IMAGE_GENERATION_MODEL = WrappedConfig( + "IMAGE_GENERATION_MODEL", + "image_generation.model", + os.getenv("IMAGE_GENERATION_MODEL", ""), +) #################################### # Audio #################################### -AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) -AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY) -AUDIO_OPENAI_API_MODEL = os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1") -AUDIO_OPENAI_API_VOICE = os.getenv("AUDIO_OPENAI_API_VOICE", "alloy") +AUDIO_OPENAI_API_BASE_URL = WrappedConfig( + "AUDIO_OPENAI_API_BASE_URL", + "audio.openai.api_base_url", + os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +) +AUDIO_OPENAI_API_KEY = WrappedConfig( + "AUDIO_OPENAI_API_KEY", + "audio.openai.api_key", + os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY), +) +AUDIO_OPENAI_API_MODEL = WrappedConfig( + "AUDIO_OPENAI_API_MODEL", + "audio.openai.api_model", + os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"), +) +AUDIO_OPENAI_API_VOICE = WrappedConfig( + "AUDIO_OPENAI_API_VOICE", + "audio.openai.api_voice", + os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"), +) #################################### # LiteLLM @@ -612,7 +789,6 @@ if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535: raise ValueError("Invalid port number for LITELLM_PROXY_PORT") LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1") - #################################### # Database #################################### diff --git a/backend/main.py b/backend/main.py index 139819f7c..6f94a8dad 100644 --- a/backend/main.py +++ b/backend/main.py @@ -58,6 +58,8 @@ from config import ( SRC_LOG_LEVELS, WEBHOOK_URL, ENABLE_ADMIN_EXPORT, + config_get, + config_set, ) from constants import ERROR_MESSAGES @@ -243,9 +245,11 @@ async def get_app_config(): "version": VERSION, "auth": WEBUI_AUTH, "default_locale": default_locale, - "images": images_app.state.ENABLED, - "default_models": webui_app.state.DEFAULT_MODELS, - "default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS, + "images": config_get(images_app.state.ENABLED), + "default_models": config_get(webui_app.state.DEFAULT_MODELS), + "default_prompt_suggestions": config_get( + webui_app.state.DEFAULT_PROMPT_SUGGESTIONS + ), "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), "admin_export_enabled": ENABLE_ADMIN_EXPORT, } @@ -254,8 +258,8 @@ async def get_app_config(): @app.get("/api/config/model/filter") async def get_model_filter_config(user=Depends(get_admin_user)): return { - "enabled": app.state.ENABLE_MODEL_FILTER, - "models": app.state.MODEL_FILTER_LIST, + "enabled": config_get(app.state.ENABLE_MODEL_FILTER), + "models": config_get(app.state.MODEL_FILTER_LIST), } @@ -268,28 +272,28 @@ class ModelFilterConfigForm(BaseModel): async def update_model_filter_config( form_data: ModelFilterConfigForm, user=Depends(get_admin_user) ): - app.state.ENABLE_MODEL_FILTER = form_data.enabled - app.state.MODEL_FILTER_LIST = form_data.models + config_set(app.state.ENABLE_MODEL_FILTER, form_data.enabled) + config_set(app.state.MODEL_FILTER_LIST, form_data.models) - ollama_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER - ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST + ollama_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) + ollama_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) - openai_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER - openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST + openai_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) + openai_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) - litellm_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER - litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST + litellm_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) + litellm_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) return { - "enabled": app.state.ENABLE_MODEL_FILTER, - "models": app.state.MODEL_FILTER_LIST, + "enabled": config_get(app.state.ENABLE_MODEL_FILTER), + "models": config_get(app.state.MODEL_FILTER_LIST), } @app.get("/api/webhook") async def get_webhook_url(user=Depends(get_admin_user)): return { - "url": app.state.WEBHOOK_URL, + "url": config_get(app.state.WEBHOOK_URL), } @@ -299,12 +303,12 @@ class UrlForm(BaseModel): @app.post("/api/webhook") async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): - app.state.WEBHOOK_URL = form_data.url + config_set(app.state.WEBHOOK_URL, form_data.url) - webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL + webui_app.state.WEBHOOK_URL = config_get(app.state.WEBHOOK_URL) return { - "url": app.state.WEBHOOK_URL, + "url": config_get(app.state.WEBHOOK_URL), } From f712c900193d90b249f80efc15b22494dd467a30 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Fri, 10 May 2024 14:18:39 +0800 Subject: [PATCH 2/6] feat: raise an exception if a WrappedConfig is used as a response --- backend/config.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/backend/config.py b/backend/config.py index 028e6caf0..9e7a9ef90 100644 --- a/backend/config.py +++ b/backend/config.py @@ -29,6 +29,7 @@ try: except ImportError: print("dotenv not installed, skipping...") + #################################### # LOGGING #################################### @@ -78,6 +79,7 @@ WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000") WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" + #################################### # ENV (dev,test,prod) #################################### @@ -148,8 +150,10 @@ for version in soup.find_all("h2"): changelog_json[version_number] = version_data + CHANGELOG = changelog_json + #################################### # WEBUI_VERSION #################################### @@ -212,6 +216,19 @@ class WrappedConfig(Generic[T]): def __str__(self): return str(self.value) + @property + def __dict__(self): + raise TypeError( + "WrappedConfig object cannot be converted to dict, use config_get or .value instead." + ) + + def __getattribute__(self, item): + if item == "__dict__": + raise TypeError( + "WrappedConfig object cannot be converted to dict, use config_get or .value instead." + ) + return super().__getattribute__(item) + def save(self): # Don't save if the value is the same as the env value and the config value if self.env_value == self.value: @@ -297,6 +314,7 @@ if CUSTOM_NAME: log.exception(e) pass + #################################### # File Upload DIR #################################### @@ -304,6 +322,7 @@ if CUSTOM_NAME: UPLOAD_DIR = f"{DATA_DIR}/uploads" Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) + #################################### # Cache DIR #################################### @@ -311,6 +330,7 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) CACHE_DIR = f"{DATA_DIR}/cache" Path(CACHE_DIR).mkdir(parents=True, exist_ok=True) + #################################### # Docs DIR #################################### @@ -351,6 +371,7 @@ if not os.path.exists(LITELLM_CONFIG_PATH): create_config_file(LITELLM_CONFIG_PATH) log.info("Config file created successfully.") + #################################### # OLLAMA_BASE_URL #################################### @@ -381,6 +402,7 @@ if ENV == "prod": elif K8S_FLAG: OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434" + OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "") OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL @@ -396,6 +418,7 @@ OLLAMA_BASE_URLS = WrappedConfig( OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") + if OPENAI_API_BASE_URL == "": OPENAI_API_BASE_URL = "https://api.openai.com/v1" @@ -614,6 +637,7 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) + if CHROMA_HTTP_HOST != "": CHROMA_CLIENT = chromadb.HttpClient( host=CHROMA_HTTP_HOST, @@ -632,6 +656,7 @@ else: database=CHROMA_DATABASE, ) + # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") @@ -700,6 +725,7 @@ WHISPER_MODEL_AUTO_UPDATE = ( os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" ) + #################################### # Images #################################### @@ -789,6 +815,7 @@ if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535: raise ValueError("Invalid port number for LITELLM_PROXY_PORT") LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1") + #################################### # Database #################################### From 298e6848b383bf5b3d4590ffa3a95f8dc8450f7b Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Fri, 10 May 2024 15:03:24 +0800 Subject: [PATCH 3/6] feat: switch to config proxy, remove config_get/set --- backend/apps/audio/main.py | 46 +++-- backend/apps/images/main.py | 152 ++++++++-------- backend/apps/ollama/main.py | 60 +++---- backend/apps/openai/main.py | 52 +++--- backend/apps/rag/main.py | 258 +++++++++++++--------------- backend/apps/web/main.py | 22 +-- backend/apps/web/routers/auths.py | 38 ++-- backend/apps/web/routers/configs.py | 9 +- backend/apps/web/routers/users.py | 8 +- backend/config.py | 24 +-- backend/main.py | 50 +++--- 11 files changed, 340 insertions(+), 379 deletions(-) diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index c3dc6a2c4..0f65a551e 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -45,8 +45,7 @@ from config import ( AUDIO_OPENAI_API_KEY, AUDIO_OPENAI_API_MODEL, AUDIO_OPENAI_API_VOICE, - config_get, - config_set, + AppConfig, ) log = logging.getLogger(__name__) @@ -61,11 +60,11 @@ app.add_middleware( allow_headers=["*"], ) - -app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL -app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY -app.state.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL -app.state.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE +app.state.config = AppConfig() +app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL +app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY +app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL +app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE # setting device type for whisper model whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" @@ -85,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel): @app.get("/config") async def get_openai_config(user=Depends(get_admin_user)): return { - "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), - "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), - "OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL), - "OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE), + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, + "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, } @@ -99,22 +98,17 @@ async def update_openai_config( if form_data.key == "": raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) - config_set(app.state.OPENAI_API_BASE_URL, form_data.url) - config_set(app.state.OPENAI_API_KEY, form_data.key) - config_set(app.state.OPENAI_API_MODEL, form_data.model) - config_set(app.state.OPENAI_API_VOICE, form_data.speaker) - - app.state.OPENAI_API_BASE_URL.save() - app.state.OPENAI_API_KEY.save() - app.state.OPENAI_API_MODEL.save() - app.state.OPENAI_API_VOICE.save() + app.state.config.OPENAI_API_BASE_URL = form_data.url + app.state.config.OPENAI_API_KEY = form_data.key + app.state.config.OPENAI_API_MODEL = form_data.model + app.state.config.OPENAI_API_VOICE = form_data.speaker return { "status": True, - "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), - "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), - "OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL), - "OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE), + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, + "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, } @@ -131,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) headers = {} - headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" + headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" headers["Content-Type"] = "application/json" r = None try: r = requests.post( - url=f"{app.state.OPENAI_API_BASE_URL}/audio/speech", + url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech", data=body, headers=headers, stream=True, diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 8ebfb0446..1c309439d 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -42,8 +42,7 @@ from config import ( IMAGE_GENERATION_MODEL, IMAGE_SIZE, IMAGE_STEPS, - config_get, - config_set, + AppConfig, ) @@ -62,28 +61,30 @@ app.add_middleware( allow_headers=["*"], ) -app.state.ENGINE = IMAGE_GENERATION_ENGINE -app.state.ENABLED = ENABLE_IMAGE_GENERATION +app.state.config = AppConfig() -app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL -app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY +app.state.config.ENGINE = IMAGE_GENERATION_ENGINE +app.state.config.ENABLED = ENABLE_IMAGE_GENERATION -app.state.MODEL = IMAGE_GENERATION_MODEL +app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL +app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY + +app.state.config.MODEL = IMAGE_GENERATION_MODEL -app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL -app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL +app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL +app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL -app.state.IMAGE_SIZE = IMAGE_SIZE -app.state.IMAGE_STEPS = IMAGE_STEPS +app.state.config.IMAGE_SIZE = IMAGE_SIZE +app.state.config.IMAGE_STEPS = IMAGE_STEPS @app.get("/config") async def get_config(request: Request, user=Depends(get_admin_user)): return { - "engine": config_get(app.state.ENGINE), - "enabled": config_get(app.state.ENABLED), + "engine": app.state.config.ENGINE, + "enabled": app.state.config.ENABLED, } @@ -94,11 +95,11 @@ class ConfigUpdateForm(BaseModel): @app.post("/config/update") async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - config_set(app.state.ENGINE, form_data.engine) - config_set(app.state.ENABLED, form_data.enabled) + app.state.config.ENGINE = form_data.engine + app.state.config.ENABLED = form_data.enabled return { - "engine": config_get(app.state.ENGINE), - "enabled": config_get(app.state.ENABLED), + "engine": app.state.config.ENGINE, + "enabled": app.state.config.ENABLED, } @@ -110,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel): @app.get("/url") async def get_engine_url(user=Depends(get_admin_user)): return { - "AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL), - "COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL), + "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, + "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, } @@ -121,29 +122,29 @@ async def update_engine_url( ): if form_data.AUTOMATIC1111_BASE_URL == None: - config_set(app.state.AUTOMATIC1111_BASE_URL, config_get(AUTOMATIC1111_BASE_URL)) + app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL else: url = form_data.AUTOMATIC1111_BASE_URL.strip("/") try: r = requests.head(url) - config_set(app.state.AUTOMATIC1111_BASE_URL, url) + app.state.config.AUTOMATIC1111_BASE_URL = url except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) if form_data.COMFYUI_BASE_URL == None: - config_set(app.state.COMFYUI_BASE_URL, COMFYUI_BASE_URL) + app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL else: url = form_data.COMFYUI_BASE_URL.strip("/") try: r = requests.head(url) - config_set(app.state.COMFYUI_BASE_URL, url) + app.state.config.COMFYUI_BASE_URL = url except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) return { - "AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL), - "COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL), + "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, + "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, "status": True, } @@ -156,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel): @app.get("/openai/config") async def get_openai_config(user=Depends(get_admin_user)): return { - "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), - "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, } @@ -168,13 +169,13 @@ async def update_openai_config( if form_data.key == "": raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) - config_set(app.state.OPENAI_API_BASE_URL, form_data.url) - config_set(app.state.OPENAI_API_KEY, form_data.key) + app.state.config.OPENAI_API_BASE_URL = form_data.url + app.state.config.OPENAI_API_KEY = form_data.key return { "status": True, - "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), - "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, } @@ -184,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel): @app.get("/size") async def get_image_size(user=Depends(get_admin_user)): - return {"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE)} + return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE} @app.post("/size/update") @@ -193,9 +194,9 @@ async def update_image_size( ): pattern = r"^\d+x\d+$" # Regular expression pattern if re.match(pattern, form_data.size): - config_set(app.state.IMAGE_SIZE, form_data.size) + app.state.config.IMAGE_SIZE = form_data.size return { - "IMAGE_SIZE": config_get(app.state.IMAGE_SIZE), + "IMAGE_SIZE": app.state.config.IMAGE_SIZE, "status": True, } else: @@ -211,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel): @app.get("/steps") async def get_image_size(user=Depends(get_admin_user)): - return {"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS)} + return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS} @app.post("/steps/update") @@ -219,9 +220,9 @@ async def update_image_size( form_data: ImageStepsUpdateForm, user=Depends(get_admin_user) ): if form_data.steps >= 0: - config_set(app.state.IMAGE_STEPS, form_data.steps) + app.state.config.IMAGE_STEPS = form_data.steps return { - "IMAGE_STEPS": config_get(app.state.IMAGE_STEPS), + "IMAGE_STEPS": app.state.config.IMAGE_STEPS, "status": True, } else: @@ -234,14 +235,14 @@ async def update_image_size( @app.get("/models") def get_models(user=Depends(get_current_user)): try: - if app.state.ENGINE == "openai": + if app.state.config.ENGINE == "openai": return [ {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] - elif app.state.ENGINE == "comfyui": + elif app.state.config.ENGINE == "comfyui": - r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info") + r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") info = r.json() return list( @@ -253,7 +254,7 @@ def get_models(user=Depends(get_current_user)): else: r = requests.get( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" ) models = r.json() return list( @@ -263,33 +264,29 @@ def get_models(user=Depends(get_current_user)): ) ) except Exception as e: - app.state.ENABLED = False + app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @app.get("/models/default") async def get_default_model(user=Depends(get_admin_user)): try: - if app.state.ENGINE == "openai": + if app.state.config.ENGINE == "openai": return { "model": ( - config_get(app.state.MODEL) - if config_get(app.state.MODEL) - else "dall-e-2" - ) - } - elif app.state.ENGINE == "comfyui": - return { - "model": ( - config_get(app.state.MODEL) if config_get(app.state.MODEL) else "" + app.state.config.MODEL if app.state.config.MODEL else "dall-e-2" ) } + elif app.state.config.ENGINE == "comfyui": + return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")} else: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + r = requests.get( + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options" + ) options = r.json() return {"model": options["sd_model_checkpoint"]} except Exception as e: - config_set(app.state.ENABLED, False) + app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -298,17 +295,20 @@ class UpdateModelForm(BaseModel): def set_model_handler(model: str): - if app.state.ENGINE in ["openai", "comfyui"]: - config_set(app.state.MODEL, model) - return config_get(app.state.MODEL) + if app.state.config.ENGINE in ["openai", "comfyui"]: + app.state.config.MODEL = model + return app.state.config.MODEL else: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + r = requests.get( + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options" + ) options = r.json() if model != options["sd_model_checkpoint"]: options["sd_model_checkpoint"] = model r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + json=options, ) return options @@ -397,30 +397,32 @@ def generate_image( user=Depends(get_current_user), ): - width, height = tuple(map(int, config_get(app.state.IMAGE_SIZE).split("x"))) + width, height = tuple(map(int, app.state.config.IMAGE_SIZE).split("x")) r = None try: - if app.state.ENGINE == "openai": + if app.state.config.ENGINE == "openai": headers = {} - headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" + headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" headers["Content-Type"] = "application/json" data = { - "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", + "model": ( + app.state.config.MODEL + if app.state.config.MODEL != "" + else "dall-e-2" + ), "prompt": form_data.prompt, "n": form_data.n, "size": ( - form_data.size - if form_data.size - else config_get(app.state.IMAGE_SIZE) + form_data.size if form_data.size else app.state.config.IMAGE_SIZE ), "response_format": "b64_json", } r = requests.post( - url=f"{app.state.OPENAI_API_BASE_URL}/images/generations", + url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations", json=data, headers=headers, ) @@ -440,7 +442,7 @@ def generate_image( return images - elif app.state.ENGINE == "comfyui": + elif app.state.config.ENGINE == "comfyui": data = { "prompt": form_data.prompt, @@ -449,8 +451,8 @@ def generate_image( "n": form_data.n, } - if config_get(app.state.IMAGE_STEPS) is not None: - data["steps"] = config_get(app.state.IMAGE_STEPS) + if app.state.config.IMAGE_STEPS is not None: + data["steps"] = app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt @@ -458,10 +460,10 @@ def generate_image( data = ImageGenerationPayload(**data) res = comfyui_generate_image( - config_get(app.state.MODEL), + app.state.config.MODEL, data, user.id, - config_get(app.state.COMFYUI_BASE_URL), + app.state.config.COMFYUI_BASE_URL, ) log.debug(f"res: {res}") @@ -488,14 +490,14 @@ def generate_image( "height": height, } - if config_get(app.state.IMAGE_STEPS) is not None: - data["steps"] = config_get(app.state.IMAGE_STEPS) + if app.state.config.IMAGE_STEPS is not None: + data["steps"] = app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", json=data, ) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 7dfadbb0c..cb80eeed2 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -46,8 +46,7 @@ from config import ( ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, UPLOAD_DIR, - config_set, - config_get, + AppConfig, ) from utils.misc import calculate_sha256 @@ -63,11 +62,12 @@ app.add_middleware( allow_headers=["*"], ) +app.state.config = AppConfig() app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS +app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} @@ -98,7 +98,7 @@ async def get_status(): @app.get("/urls") async def get_ollama_api_urls(user=Depends(get_admin_user)): - return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)} + return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} class UrlUpdateForm(BaseModel): @@ -107,10 +107,10 @@ class UrlUpdateForm(BaseModel): @app.post("/urls/update") async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - config_set(app.state.OLLAMA_BASE_URLS, form_data.urls) + app.state.config.OLLAMA_BASE_URLS = form_data.urls - log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}") - return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)} + log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}") + return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} @app.get("/cancel/{request_id}") @@ -155,9 +155,7 @@ def merge_models_lists(model_lists): async def get_all_models(): log.info("get_all_models()") - tasks = [ - fetch_url(f"{url}/api/tags") for url in config_get(app.state.OLLAMA_BASE_URLS) - ] + tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS] responses = await asyncio.gather(*tasks) models = { @@ -183,15 +181,14 @@ async def get_ollama_tags( if user.role == "user": models["models"] = list( filter( - lambda model: model["name"] - in config_get(app.state.MODEL_FILTER_LIST), + lambda model: model["name"] in app.state.MODEL_FILTER_LIST, models["models"], ) ) return models return models else: - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/tags") r.raise_for_status() @@ -222,8 +219,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): # returns lowest version tasks = [ - fetch_url(f"{url}/api/version") - for url in config_get(app.state.OLLAMA_BASE_URLS) + fetch_url(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS ] responses = await asyncio.gather(*tasks) responses = list(filter(lambda x: x is not None, responses)) @@ -243,7 +239,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, ) else: - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/version") r.raise_for_status() @@ -275,7 +271,7 @@ class ModelNameForm(BaseModel): async def pull_model( form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) ): - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -363,7 +359,7 @@ async def push_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.debug(f"url: {url}") r = None @@ -425,7 +421,7 @@ async def create_model( form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) ): log.debug(f"form_data: {form_data}") - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -498,7 +494,7 @@ async def copy_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -545,7 +541,7 @@ async def delete_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -585,7 +581,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us ) url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -642,7 +638,7 @@ async def generate_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -692,7 +688,7 @@ def generate_ollama_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") try: @@ -761,7 +757,7 @@ async def generate_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -864,7 +860,7 @@ async def generate_chat_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -973,7 +969,7 @@ async def generate_openai_chat_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") r = None @@ -1072,7 +1068,7 @@ async def get_openai_models( } else: - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/tags") r.raise_for_status() @@ -1206,7 +1202,7 @@ async def download_model( if url_idx == None: url_idx = 0 - url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + url = app.state.config.OLLAMA_BASE_URLS[url_idx] file_name = parse_huggingface_url(form_data.url) @@ -1225,7 +1221,7 @@ async def download_model( def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): if url_idx == None: url_idx = 0 - ollama_url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] + ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] file_path = f"{UPLOAD_DIR}/{file.filename}" @@ -1290,7 +1286,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): # async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None): # if url_idx == None: # url_idx = 0 -# url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] +# url = app.state.config.OLLAMA_BASE_URLS[url_idx] # file_location = os.path.join(UPLOAD_DIR, file.filename) # total_size = file.size @@ -1327,7 +1323,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): async def deprecated_proxy( path: str, request: Request, user=Depends(get_verified_user) ): - url = config_get(app.state.OLLAMA_BASE_URLS)[0] + url = app.state.config.OLLAMA_BASE_URLS[0] target_url = f"{url}/{path}" body = await request.body() diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 36fed104c..5112ebb62 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -26,8 +26,7 @@ from config import ( CACHE_DIR, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, - config_set, - config_get, + AppConfig, ) from typing import List, Optional @@ -47,11 +46,13 @@ app.add_middleware( allow_headers=["*"], ) +app.state.config = AppConfig() + app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS -app.state.OPENAI_API_KEYS = OPENAI_API_KEYS +app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS +app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS app.state.MODELS = {} @@ -77,34 +78,32 @@ class KeysUpdateForm(BaseModel): @app.get("/urls") async def get_openai_urls(user=Depends(get_admin_user)): - return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)} + return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} @app.post("/urls/update") async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): await get_all_models() - config_set(app.state.OPENAI_API_BASE_URLS, form_data.urls) - return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)} + app.state.config.OPENAI_API_BASE_URLS = form_data.urls + return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} @app.get("/keys") async def get_openai_keys(user=Depends(get_admin_user)): - return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)} + return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} @app.post("/keys/update") async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): - config_set(app.state.OPENAI_API_KEYS, form_data.keys) - return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)} + app.state.config.OPENAI_API_KEYS = form_data.keys + return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} @app.post("/audio/speech") async def speech(request: Request, user=Depends(get_verified_user)): idx = None try: - idx = config_get(app.state.OPENAI_API_BASE_URLS).index( - "https://api.openai.com/v1" - ) + idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") body = await request.body() name = hashlib.sha256(body).hexdigest() @@ -118,15 +117,13 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) headers = {} - headers["Authorization"] = ( - f"Bearer {config_get(app.state.OPENAI_API_KEYS)[idx]}" - ) + headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}" headers["Content-Type"] = "application/json" r = None try: r = requests.post( - url=f"{config_get(app.state.OPENAI_API_BASE_URLS)[idx]}/audio/speech", + url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech", data=body, headers=headers, stream=True, @@ -187,7 +184,7 @@ def merge_models_lists(model_lists): {**model, "urlIdx": idx} for model in models if "api.openai.com" - not in config_get(app.state.OPENAI_API_BASE_URLS)[idx] + not in app.state.config.OPENAI_API_BASE_URLS[idx] or "gpt" in model["id"] ] ) @@ -199,14 +196,14 @@ async def get_all_models(): log.info("get_all_models()") if ( - len(config_get(app.state.OPENAI_API_KEYS)) == 1 - and config_get(app.state.OPENAI_API_KEYS)[0] == "" + len(app.state.config.OPENAI_API_KEYS) == 1 + and app.state.config.OPENAI_API_KEYS[0] == "" ): models = {"data": []} else: tasks = [ - fetch_url(f"{url}/models", config_get(app.state.OPENAI_API_KEYS)[idx]) - for idx, url in enumerate(config_get(app.state.OPENAI_API_BASE_URLS)) + fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) + for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS) ] responses = await asyncio.gather(*tasks) @@ -238,19 +235,18 @@ async def get_all_models(): async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): if url_idx == None: models = await get_all_models() - if config_get(app.state.ENABLE_MODEL_FILTER): + if app.state.ENABLE_MODEL_FILTER: if user.role == "user": models["data"] = list( filter( - lambda model: model["id"] - in config_get(app.state.MODEL_FILTER_LIST), + lambda model: model["id"] in app.state.MODEL_FILTER_LIST, models["data"], ) ) return models return models else: - url = config_get(app.state.OPENAI_API_BASE_URLS)[url_idx] + url = app.state.config.OPENAI_API_BASE_URLS[url_idx] r = None @@ -314,8 +310,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): except json.JSONDecodeError as e: log.error("Error loading request body into a dictionary:", e) - url = config_get(app.state.OPENAI_API_BASE_URLS)[idx] - key = config_get(app.state.OPENAI_API_KEYS)[idx] + url = app.state.config.OPENAI_API_BASE_URLS[idx] + key = app.state.config.OPENAI_API_KEYS[idx] target_url = f"{url}/{path}" diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index f05447a66..d2c3964ae 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -93,8 +93,7 @@ from config import ( RAG_TEMPLATE, ENABLE_RAG_LOCAL_WEB_FETCH, YOUTUBE_LOADER_LANGUAGE, - config_set, - config_get, + AppConfig, ) from constants import ERROR_MESSAGES @@ -104,30 +103,32 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) app = FastAPI() -app.state.TOP_K = RAG_TOP_K -app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD +app.state.config = AppConfig() -app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH -app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( +app.state.config.TOP_K = RAG_TOP_K +app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD + +app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH +app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION ) -app.state.CHUNK_SIZE = CHUNK_SIZE -app.state.CHUNK_OVERLAP = CHUNK_OVERLAP +app.state.config.CHUNK_SIZE = CHUNK_SIZE +app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP -app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE -app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL -app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL -app.state.RAG_TEMPLATE = RAG_TEMPLATE +app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE +app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL +app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL +app.state.config.RAG_TEMPLATE = RAG_TEMPLATE -app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL -app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY +app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL +app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY -app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES +app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES -app.state.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE +app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE app.state.YOUTUBE_LOADER_TRANSLATION = None @@ -135,7 +136,7 @@ def update_embedding_model( embedding_model: str, update_model: bool = False, ): - if embedding_model and config_get(app.state.RAG_EMBEDDING_ENGINE) == "": + if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( get_model_path(embedding_model, update_model), device=DEVICE_TYPE, @@ -160,22 +161,22 @@ def update_reranking_model( update_embedding_model( - config_get(app.state.RAG_EMBEDDING_MODEL), + app.state.config.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, ) update_reranking_model( - config_get(app.state.RAG_RERANKING_MODEL), + app.state.config.RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, ) app.state.EMBEDDING_FUNCTION = get_embedding_function( - config_get(app.state.RAG_EMBEDDING_ENGINE), - config_get(app.state.RAG_EMBEDDING_MODEL), + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - config_get(app.state.OPENAI_API_KEY), - config_get(app.state.OPENAI_API_BASE_URL), + app.state.config.OPENAI_API_KEY, + app.state.config.OPENAI_API_BASE_URL, ) origins = ["*"] @@ -202,12 +203,12 @@ class UrlForm(CollectionNameForm): async def get_status(): return { "status": True, - "chunk_size": config_get(app.state.CHUNK_SIZE), - "chunk_overlap": config_get(app.state.CHUNK_OVERLAP), - "template": config_get(app.state.RAG_TEMPLATE), - "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE), - "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL), - "reranking_model": config_get(app.state.RAG_RERANKING_MODEL), + "chunk_size": app.state.config.CHUNK_SIZE, + "chunk_overlap": app.state.config.CHUNK_OVERLAP, + "template": app.state.config.RAG_TEMPLATE, + "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, + "reranking_model": app.state.config.RAG_RERANKING_MODEL, } @@ -215,11 +216,11 @@ async def get_status(): async def get_embedding_config(user=Depends(get_admin_user)): return { "status": True, - "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE), - "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL), + "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, "openai_config": { - "url": config_get(app.state.OPENAI_API_BASE_URL), - "key": config_get(app.state.OPENAI_API_KEY), + "url": app.state.config.OPENAI_API_BASE_URL, + "key": app.state.config.OPENAI_API_KEY, }, } @@ -228,7 +229,7 @@ async def get_embedding_config(user=Depends(get_admin_user)): async def get_reraanking_config(user=Depends(get_admin_user)): return { "status": True, - "reranking_model": config_get(app.state.RAG_RERANKING_MODEL), + "reranking_model": app.state.config.RAG_RERANKING_MODEL, } @@ -248,34 +249,34 @@ async def update_embedding_config( form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): log.info( - f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" + f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) try: - config_set(app.state.RAG_EMBEDDING_ENGINE, form_data.embedding_engine) - config_set(app.state.RAG_EMBEDDING_MODEL, form_data.embedding_model) + app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine + app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model - if config_get(app.state.RAG_EMBEDDING_ENGINE) in ["ollama", "openai"]: + if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: if form_data.openai_config != None: - config_set(app.state.OPENAI_API_BASE_URL, form_data.openai_config.url) - config_set(app.state.OPENAI_API_KEY, form_data.openai_config.key) + app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url + app.state.config.OPENAI_API_KEY = form_data.openai_config.key - update_embedding_model(config_get(app.state.RAG_EMBEDDING_MODEL), True) + update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL), True app.state.EMBEDDING_FUNCTION = get_embedding_function( - config_get(app.state.RAG_EMBEDDING_ENGINE), - config_get(app.state.RAG_EMBEDDING_MODEL), + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - config_get(app.state.OPENAI_API_KEY), - config_get(app.state.OPENAI_API_BASE_URL), + app.state.config.OPENAI_API_KEY, + app.state.config.OPENAI_API_BASE_URL, ) return { "status": True, - "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE), - "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL), + "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, "openai_config": { - "url": config_get(app.state.OPENAI_API_BASE_URL), - "key": config_get(app.state.OPENAI_API_KEY), + "url": app.state.config.OPENAI_API_BASE_URL, + "key": app.state.config.OPENAI_API_KEY, }, } except Exception as e: @@ -295,16 +296,16 @@ async def update_reranking_config( form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) ): log.info( - f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}" + f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}" ) try: - config_set(app.state.RAG_RERANKING_MODEL, form_data.reranking_model) + app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model - update_reranking_model(config_get(app.state.RAG_RERANKING_MODEL), True) + update_reranking_model(app.state.config.RAG_RERANKING_MODEL), True return { "status": True, - "reranking_model": config_get(app.state.RAG_RERANKING_MODEL), + "reranking_model": app.state.config.RAG_RERANKING_MODEL, } except Exception as e: log.exception(f"Problem updating reranking model: {e}") @@ -318,16 +319,14 @@ async def update_reranking_config( async def get_rag_config(user=Depends(get_admin_user)): return { "status": True, - "pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES), + "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, "chunk": { - "chunk_size": config_get(app.state.CHUNK_SIZE), - "chunk_overlap": config_get(app.state.CHUNK_OVERLAP), + "chunk_size": app.state.config.CHUNK_SIZE, + "chunk_overlap": app.state.config.CHUNK_OVERLAP, }, - "web_loader_ssl_verification": config_get( - app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION - ), + "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "youtube": { - "language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE), + "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, } @@ -352,49 +351,34 @@ class ConfigUpdateForm(BaseModel): @app.post("/config/update") async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - config_set( - app.state.PDF_EXTRACT_IMAGES, - ( - form_data.pdf_extract_images - if form_data.pdf_extract_images is not None - else config_get(app.state.PDF_EXTRACT_IMAGES) - ), + app.state.config.PDF_EXTRACT_IMAGES = ( + form_data.pdf_extract_images + if form_data.pdf_extract_images is not None + else app.state.config.PDF_EXTRACT_IMAGES ) - config_set( - app.state.CHUNK_SIZE, - ( - form_data.chunk.chunk_size - if form_data.chunk is not None - else config_get(app.state.CHUNK_SIZE) - ), + app.state.config.CHUNK_SIZE = ( + form_data.chunk.chunk_size + if form_data.chunk is not None + else app.state.config.CHUNK_SIZE ) - config_set( - app.state.CHUNK_OVERLAP, - ( - form_data.chunk.chunk_overlap - if form_data.chunk is not None - else config_get(app.state.CHUNK_OVERLAP) - ), + app.state.config.CHUNK_OVERLAP = ( + form_data.chunk.chunk_overlap + if form_data.chunk is not None + else app.state.config.CHUNK_OVERLAP ) - config_set( - app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - ( - form_data.web_loader_ssl_verification - if form_data.web_loader_ssl_verification != None - else config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION) - ), + app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( + form_data.web_loader_ssl_verification + if form_data.web_loader_ssl_verification != None + else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION ) - config_set( - app.state.YOUTUBE_LOADER_LANGUAGE, - ( - form_data.youtube.language - if form_data.youtube is not None - else config_get(app.state.YOUTUBE_LOADER_LANGUAGE) - ), + app.state.config.YOUTUBE_LOADER_LANGUAGE = ( + form_data.youtube.language + if form_data.youtube is not None + else app.state.config.YOUTUBE_LOADER_LANGUAGE ) app.state.YOUTUBE_LOADER_TRANSLATION = ( @@ -405,16 +389,14 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ return { "status": True, - "pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES), + "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, "chunk": { - "chunk_size": config_get(app.state.CHUNK_SIZE), - "chunk_overlap": config_get(app.state.CHUNK_OVERLAP), + "chunk_size": app.state.config.CHUNK_SIZE, + "chunk_overlap": app.state.config.CHUNK_OVERLAP, }, - "web_loader_ssl_verification": config_get( - app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION - ), + "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "youtube": { - "language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE), + "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, } @@ -424,7 +406,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ async def get_rag_template(user=Depends(get_current_user)): return { "status": True, - "template": config_get(app.state.RAG_TEMPLATE), + "template": app.state.config.RAG_TEMPLATE, } @@ -432,10 +414,10 @@ async def get_rag_template(user=Depends(get_current_user)): async def get_query_settings(user=Depends(get_admin_user)): return { "status": True, - "template": config_get(app.state.RAG_TEMPLATE), - "k": config_get(app.state.TOP_K), - "r": config_get(app.state.RELEVANCE_THRESHOLD), - "hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH), + "template": app.state.config.RAG_TEMPLATE, + "k": app.state.config.TOP_K, + "r": app.state.config.RELEVANCE_THRESHOLD, + "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, } @@ -450,22 +432,20 @@ class QuerySettingsForm(BaseModel): async def update_query_settings( form_data: QuerySettingsForm, user=Depends(get_admin_user) ): - config_set( - app.state.RAG_TEMPLATE, + app.state.config.RAG_TEMPLATE = ( form_data.template if form_data.template else RAG_TEMPLATE, ) - config_set(app.state.TOP_K, form_data.k if form_data.k else 4) - config_set(app.state.RELEVANCE_THRESHOLD, form_data.r if form_data.r else 0.0) - config_set( - app.state.ENABLE_RAG_HYBRID_SEARCH, + app.state.config.TOP_K = form_data.k if form_data.k else 4 + app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 + app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( form_data.hybrid if form_data.hybrid else False, ) return { "status": True, - "template": config_get(app.state.RAG_TEMPLATE), - "k": config_get(app.state.TOP_K), - "r": config_get(app.state.RELEVANCE_THRESHOLD), - "hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH), + "template": app.state.config.RAG_TEMPLATE, + "k": app.state.config.TOP_K, + "r": app.state.config.RELEVANCE_THRESHOLD, + "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, } @@ -483,17 +463,15 @@ def query_doc_handler( user=Depends(get_current_user), ): try: - if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH): + if app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_doc_with_hybrid_search( collection_name=form_data.collection_name, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else config_get(app.state.TOP_K), + k=form_data.k if form_data.k else app.state.config.TOP_K, reranking_function=app.state.sentence_transformer_rf, r=( - form_data.r - if form_data.r - else config_get(app.state.RELEVANCE_THRESHOLD) + form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD ), ) else: @@ -501,7 +479,7 @@ def query_doc_handler( collection_name=form_data.collection_name, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else config_get(app.state.TOP_K), + k=form_data.k if form_data.k else app.state.config.TOP_K, ) except Exception as e: log.exception(e) @@ -525,17 +503,15 @@ def query_collection_handler( user=Depends(get_current_user), ): try: - if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH): + if app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_collection_with_hybrid_search( collection_names=form_data.collection_names, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else config_get(app.state.TOP_K), + k=form_data.k if form_data.k else app.state.config.TOP_K, reranking_function=app.state.sentence_transformer_rf, r=( - form_data.r - if form_data.r - else config_get(app.state.RELEVANCE_THRESHOLD) + form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD ), ) else: @@ -543,7 +519,7 @@ def query_collection_handler( collection_names=form_data.collection_names, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else config_get(app.state.TOP_K), + k=form_data.k if form_data.k else app.state.config.TOP_K, ) except Exception as e: @@ -560,8 +536,8 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)): loader = YoutubeLoader.from_youtube_url( form_data.url, add_video_info=True, - language=config_get(app.state.YOUTUBE_LOADER_LANGUAGE), - translation=config_get(app.state.YOUTUBE_LOADER_TRANSLATION), + language=app.state.config.YOUTUBE_LOADER_LANGUAGE, + translation=app.state.YOUTUBE_LOADER_TRANSLATION, ) data = loader.load() @@ -589,7 +565,7 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)): try: loader = get_web_loader( form_data.url, - verify_ssl=config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION), + verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, ) data = loader.load() @@ -645,8 +621,8 @@ def resolve_hostname(hostname): def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=config_get(app.state.CHUNK_SIZE), - chunk_overlap=config_get(app.state.CHUNK_OVERLAP), + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, add_start_index=True, ) @@ -663,8 +639,8 @@ def store_text_in_vector_db( text, metadata, collection_name, overwrite: bool = False ) -> bool: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=config_get(app.state.CHUNK_SIZE), - chunk_overlap=config_get(app.state.CHUNK_OVERLAP), + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, add_start_index=True, ) docs = text_splitter.create_documents([text], metadatas=[metadata]) @@ -687,11 +663,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b collection = CHROMA_CLIENT.create_collection(name=collection_name) embedding_func = get_embedding_function( - config_get(app.state.RAG_EMBEDDING_ENGINE), - config_get(app.state.RAG_EMBEDDING_MODEL), + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - config_get(app.state.OPENAI_API_KEY), - config_get(app.state.OPENAI_API_BASE_URL), + app.state.config.OPENAI_API_KEY, + app.state.config.OPENAI_API_BASE_URL, ) embedding_texts = list(map(lambda x: x.replace("\n", " "), texts)) @@ -766,7 +742,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str): if file_ext == "pdf": loader = PyPDFLoader( - file_path, extract_images=config_get(app.state.PDF_EXTRACT_IMAGES) + file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES ) elif file_ext == "csv": loader = CSVLoader(file_path) diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index 2bed33543..755e3911b 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -22,21 +22,23 @@ from config import ( WEBHOOK_URL, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, JWT_EXPIRES_IN, - config_get, + AppConfig, ) app = FastAPI() origins = ["*"] -app.state.ENABLE_SIGNUP = ENABLE_SIGNUP -app.state.JWT_EXPIRES_IN = JWT_EXPIRES_IN +app.state.config = AppConfig() -app.state.DEFAULT_MODELS = DEFAULT_MODELS -app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS -app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE -app.state.USER_PERMISSIONS = USER_PERMISSIONS -app.state.WEBHOOK_URL = WEBHOOK_URL +app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP +app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN + +app.state.config.DEFAULT_MODELS = DEFAULT_MODELS +app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS +app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE +app.state.config.USER_PERMISSIONS = USER_PERMISSIONS +app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.add_middleware( @@ -63,6 +65,6 @@ async def get_status(): return { "status": True, "auth": WEBUI_AUTH, - "default_models": config_get(app.state.DEFAULT_MODELS), - "default_prompt_suggestions": config_get(app.state.DEFAULT_PROMPT_SUGGESTIONS), + "default_models": app.state.config.DEFAULT_MODELS, + "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, } diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index 0bc4967f9..998e74659 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -33,7 +33,7 @@ from utils.utils import ( from utils.misc import parse_duration, validate_email_format from utils.webhook import post_webhook from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES -from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, config_get, config_set +from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER router = APIRouter() @@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm): if user: token = create_token( data={"id": user.id}, - expires_delta=parse_duration(config_get(request.app.state.JWT_EXPIRES_IN)), + expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), ) return { @@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm): @router.post("/signup", response_model=SigninResponse) async def signup(request: Request, form_data: SignupForm): - if not config_get(request.app.state.ENABLE_SIGNUP) and WEBUI_AUTH: + if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH: raise HTTPException( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED ) @@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm): role = ( "admin" if Users.get_num_users() == 0 - else config_get(request.app.state.DEFAULT_USER_ROLE) + else request.app.state.config.DEFAULT_USER_ROLE ) hashed = get_password_hash(form_data.password) user = Auths.insert_new_auth( @@ -194,15 +194,13 @@ async def signup(request: Request, form_data: SignupForm): if user: token = create_token( data={"id": user.id}, - expires_delta=parse_duration( - config_get(request.app.state.JWT_EXPIRES_IN) - ), + expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), ) # response.set_cookie(key='token', value=token, httponly=True) - if config_get(request.app.state.WEBHOOK_URL): + if request.app.state.config.WEBHOOK_URL: post_webhook( - config_get(request.app.state.WEBHOOK_URL), + request.app.state.config.WEBHOOK_URL, WEBHOOK_MESSAGES.USER_SIGNUP(user.name), { "action": "signup", @@ -278,15 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): @router.get("/signup/enabled", response_model=bool) async def get_sign_up_status(request: Request, user=Depends(get_admin_user)): - return config_get(request.app.state.ENABLE_SIGNUP) + return request.app.state.config.ENABLE_SIGNUP @router.get("/signup/enabled/toggle", response_model=bool) async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): - config_set( - request.app.state.ENABLE_SIGNUP, not config_get(request.app.state.ENABLE_SIGNUP) - ) - return config_get(request.app.state.ENABLE_SIGNUP) + request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP + return request.app.state.config.ENABLE_SIGNUP ############################ @@ -296,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): @router.get("/signup/user/role") async def get_default_user_role(request: Request, user=Depends(get_admin_user)): - return config_get(request.app.state.DEFAULT_USER_ROLE) + return request.app.state.config.DEFAULT_USER_ROLE class UpdateRoleForm(BaseModel): @@ -308,8 +304,8 @@ async def update_default_user_role( request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user) ): if form_data.role in ["pending", "user", "admin"]: - config_set(request.app.state.DEFAULT_USER_ROLE, form_data.role) - return config_get(request.app.state.DEFAULT_USER_ROLE) + request.app.state.config.DEFAULT_USER_ROLE = form_data.role + return request.app.state.config.DEFAULT_USER_ROLE ############################ @@ -319,7 +315,7 @@ async def update_default_user_role( @router.get("/token/expires") async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)): - return config_get(request.app.state.JWT_EXPIRES_IN) + return request.app.state.config.JWT_EXPIRES_IN class UpdateJWTExpiresDurationForm(BaseModel): @@ -336,10 +332,10 @@ async def update_token_expires_duration( # Check if the input string matches the pattern if re.match(pattern, form_data.duration): - config_set(request.app.state.JWT_EXPIRES_IN, form_data.duration) - return config_get(request.app.state.JWT_EXPIRES_IN) + request.app.state.config.JWT_EXPIRES_IN = form_data.duration + return request.app.state.config.JWT_EXPIRES_IN else: - return config_get(request.app.state.JWT_EXPIRES_IN) + return request.app.state.config.JWT_EXPIRES_IN ############################ diff --git a/backend/apps/web/routers/configs.py b/backend/apps/web/routers/configs.py index d726cd2dc..143ed5e0a 100644 --- a/backend/apps/web/routers/configs.py +++ b/backend/apps/web/routers/configs.py @@ -9,7 +9,6 @@ import time import uuid from apps.web.models.users import Users -from config import config_set, config_get from utils.utils import ( get_password_hash, @@ -45,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel): async def set_global_default_models( request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user) ): - config_set(request.app.state.DEFAULT_MODELS, form_data.models) - return config_get(request.app.state.DEFAULT_MODELS) + request.app.state.config.DEFAULT_MODELS = form_data.models + return request.app.state.config.DEFAULT_MODELS @router.post("/default/suggestions", response_model=List[PromptSuggestion]) @@ -56,5 +55,5 @@ async def set_global_default_suggestions( user=Depends(get_admin_user), ): data = form_data.model_dump() - config_set(request.app.state.DEFAULT_PROMPT_SUGGESTIONS, data["suggestions"]) - return config_get(request.app.state.DEFAULT_PROMPT_SUGGESTIONS) + request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] + return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS diff --git a/backend/apps/web/routers/users.py b/backend/apps/web/routers/users.py index 302432540..d87854e89 100644 --- a/backend/apps/web/routers/users.py +++ b/backend/apps/web/routers/users.py @@ -15,7 +15,7 @@ from apps.web.models.auths import Auths from utils.utils import get_current_user, get_password_hash, get_admin_user from constants import ERROR_MESSAGES -from config import SRC_LOG_LEVELS, config_set, config_get +from config import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user) @router.get("/permissions/user") async def get_user_permissions(request: Request, user=Depends(get_admin_user)): - return config_get(request.app.state.USER_PERMISSIONS) + return request.app.state.config.USER_PERMISSIONS @router.post("/permissions/user") async def update_user_permissions( request: Request, form_data: dict, user=Depends(get_admin_user) ): - config_set(request.app.state.USER_PERMISSIONS, form_data) - return config_get(request.app.state.USER_PERMISSIONS) + request.app.state.config.USER_PERMISSIONS = form_data + return request.app.state.config.USER_PERMISSIONS ############################ diff --git a/backend/config.py b/backend/config.py index 9e7a9ef90..845a812ce 100644 --- a/backend/config.py +++ b/backend/config.py @@ -246,19 +246,21 @@ class WrappedConfig(Generic[T]): self.config_value = self.value -def config_set(config: Union[WrappedConfig[T], T], value: T, save_config=True): - if isinstance(config, WrappedConfig): - config.value = value - if save_config: - config.save() - else: - config = value +class AppConfig: + _state: dict[str, WrappedConfig] + def __init__(self): + super().__setattr__("_state", {}) -def config_get(config: Union[WrappedConfig[T], T]) -> T: - if isinstance(config, WrappedConfig): - return config.value - return config + def __setattr__(self, key, value): + if isinstance(value, WrappedConfig): + self._state[key] = value + else: + self._state[key].value = value + self._state[key].save() + + def __getattr__(self, key): + return self._state[key].value #################################### diff --git a/backend/main.py b/backend/main.py index 6f94a8dad..e2d7e18a3 100644 --- a/backend/main.py +++ b/backend/main.py @@ -58,8 +58,7 @@ from config import ( SRC_LOG_LEVELS, WEBHOOK_URL, ENABLE_ADMIN_EXPORT, - config_get, - config_set, + AppConfig, ) from constants import ERROR_MESSAGES @@ -96,10 +95,11 @@ https://github.com/open-webui/open-webui app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) -app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.config = AppConfig() +app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER +app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.WEBHOOK_URL = WEBHOOK_URL +app.state.config.WEBHOOK_URL = WEBHOOK_URL origins = ["*"] @@ -245,11 +245,9 @@ async def get_app_config(): "version": VERSION, "auth": WEBUI_AUTH, "default_locale": default_locale, - "images": config_get(images_app.state.ENABLED), - "default_models": config_get(webui_app.state.DEFAULT_MODELS), - "default_prompt_suggestions": config_get( - webui_app.state.DEFAULT_PROMPT_SUGGESTIONS - ), + "images": images_app.state.config.ENABLED, + "default_models": webui_app.state.config.DEFAULT_MODELS, + "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), "admin_export_enabled": ENABLE_ADMIN_EXPORT, } @@ -258,8 +256,8 @@ async def get_app_config(): @app.get("/api/config/model/filter") async def get_model_filter_config(user=Depends(get_admin_user)): return { - "enabled": config_get(app.state.ENABLE_MODEL_FILTER), - "models": config_get(app.state.MODEL_FILTER_LIST), + "enabled": app.state.config.ENABLE_MODEL_FILTER, + "models": app.state.config.MODEL_FILTER_LIST, } @@ -272,28 +270,28 @@ class ModelFilterConfigForm(BaseModel): async def update_model_filter_config( form_data: ModelFilterConfigForm, user=Depends(get_admin_user) ): - config_set(app.state.ENABLE_MODEL_FILTER, form_data.enabled) - config_set(app.state.MODEL_FILTER_LIST, form_data.models) + app.state.config.ENABLE_MODEL_FILTER, form_data.enabled + app.state.config.MODEL_FILTER_LIST, form_data.models - ollama_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) - ollama_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) + ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER + ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - openai_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) - openai_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) + openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER + openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - litellm_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) - litellm_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) + litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER + litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST return { - "enabled": config_get(app.state.ENABLE_MODEL_FILTER), - "models": config_get(app.state.MODEL_FILTER_LIST), + "enabled": app.state.config.ENABLE_MODEL_FILTER, + "models": app.state.config.MODEL_FILTER_LIST, } @app.get("/api/webhook") async def get_webhook_url(user=Depends(get_admin_user)): return { - "url": config_get(app.state.WEBHOOK_URL), + "url": app.state.config.WEBHOOK_URL, } @@ -303,12 +301,12 @@ class UrlForm(BaseModel): @app.post("/api/webhook") async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): - config_set(app.state.WEBHOOK_URL, form_data.url) + app.state.config.WEBHOOK_URL = form_data.url - webui_app.state.WEBHOOK_URL = config_get(app.state.WEBHOOK_URL) + webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL return { - "url": config_get(app.state.WEBHOOK_URL), + "url": app.state.config.WEBHOOK_URL, } From a0dceb06a5a3c31b9be8617179fe9c1876abc8ca Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Fri, 10 May 2024 15:20:22 +0800 Subject: [PATCH 4/6] fix: nested WrappedConfig breaks things --- backend/config.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/backend/config.py b/backend/config.py index 845a812ce..a619bb746 100644 --- a/backend/config.py +++ b/backend/config.py @@ -511,10 +511,8 @@ DEFAULT_USER_ROLE = WrappedConfig( os.getenv("DEFAULT_USER_ROLE", "pending"), ) -USER_PERMISSIONS_CHAT_DELETION = WrappedConfig( - "USER_PERMISSIONS_CHAT_DELETION", - "ui.user_permissions.chat.deletion", - os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true", +USER_PERMISSIONS_CHAT_DELETION = ( + os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true" ) USER_PERMISSIONS = WrappedConfig( From 5d64822c84030c2a2c0631ddfa68362cd55463fa Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Sun, 12 May 2024 13:28:40 +0800 Subject: [PATCH 5/6] refac: rename WrappedConfig to PersistedConfig --- backend/config.py | 88 +++++++++++++++++++++++------------------------ 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/backend/config.py b/backend/config.py index a619bb746..106044540 100644 --- a/backend/config.py +++ b/backend/config.py @@ -201,7 +201,7 @@ def get_config_value(config_path: str): T = TypeVar("T") -class WrappedConfig(Generic[T]): +class PersistedConfig(Generic[T]): def __init__(self, env_name: str, config_path: str, env_value: T): self.env_name = env_name self.config_path = config_path @@ -219,13 +219,13 @@ class WrappedConfig(Generic[T]): @property def __dict__(self): raise TypeError( - "WrappedConfig object cannot be converted to dict, use config_get or .value instead." + "PersistedConfig object cannot be converted to dict, use config_get or .value instead." ) def __getattribute__(self, item): if item == "__dict__": raise TypeError( - "WrappedConfig object cannot be converted to dict, use config_get or .value instead." + "PersistedConfig object cannot be converted to dict, use config_get or .value instead." ) return super().__getattribute__(item) @@ -247,13 +247,13 @@ class WrappedConfig(Generic[T]): class AppConfig: - _state: dict[str, WrappedConfig] + _state: dict[str, PersistedConfig] def __init__(self): super().__setattr__("_state", {}) def __setattr__(self, key, value): - if isinstance(value, WrappedConfig): + if isinstance(value, PersistedConfig): self._state[key] = value else: self._state[key].value = value @@ -271,7 +271,7 @@ WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None ) -JWT_EXPIRES_IN = WrappedConfig( +JWT_EXPIRES_IN = PersistedConfig( "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") ) @@ -409,7 +409,7 @@ OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "") OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")] -OLLAMA_BASE_URLS = WrappedConfig( +OLLAMA_BASE_URLS = PersistedConfig( "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS ) @@ -428,7 +428,7 @@ OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "") OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")] -OPENAI_API_KEYS = WrappedConfig("OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS) +OPENAI_API_KEYS = PersistedConfig("OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS) OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "") OPENAI_API_BASE_URLS = ( @@ -439,7 +439,7 @@ OPENAI_API_BASE_URLS = [ url.strip() if url != "" else "https://api.openai.com/v1" for url in OPENAI_API_BASE_URLS.split(";") ] -OPENAI_API_BASE_URLS = WrappedConfig( +OPENAI_API_BASE_URLS = PersistedConfig( "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS ) @@ -458,7 +458,7 @@ OPENAI_API_BASE_URL = "https://api.openai.com/v1" # WEBUI #################################### -ENABLE_SIGNUP = WrappedConfig( +ENABLE_SIGNUP = PersistedConfig( "ENABLE_SIGNUP", "ui.enable_signup", ( @@ -467,11 +467,11 @@ ENABLE_SIGNUP = WrappedConfig( else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true" ), ) -DEFAULT_MODELS = WrappedConfig( +DEFAULT_MODELS = PersistedConfig( "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None) ) -DEFAULT_PROMPT_SUGGESTIONS = WrappedConfig( +DEFAULT_PROMPT_SUGGESTIONS = PersistedConfig( "DEFAULT_PROMPT_SUGGESTIONS", "ui.prompt_suggestions", [ @@ -505,7 +505,7 @@ DEFAULT_PROMPT_SUGGESTIONS = WrappedConfig( ], ) -DEFAULT_USER_ROLE = WrappedConfig( +DEFAULT_USER_ROLE = PersistedConfig( "DEFAULT_USER_ROLE", "ui.default_user_role", os.getenv("DEFAULT_USER_ROLE", "pending"), @@ -515,25 +515,25 @@ USER_PERMISSIONS_CHAT_DELETION = ( os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true" ) -USER_PERMISSIONS = WrappedConfig( +USER_PERMISSIONS = PersistedConfig( "USER_PERMISSIONS", "ui.user_permissions", {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}, ) -ENABLE_MODEL_FILTER = WrappedConfig( +ENABLE_MODEL_FILTER = PersistedConfig( "ENABLE_MODEL_FILTER", "model_filter.enable", os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true", ) MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") -MODEL_FILTER_LIST = WrappedConfig( +MODEL_FILTER_LIST = PersistedConfig( "MODEL_FILTER_LIST", "model_filter.list", [model.strip() for model in MODEL_FILTER_LIST.split(";")], ) -WEBHOOK_URL = WrappedConfig( +WEBHOOK_URL = PersistedConfig( "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "") ) @@ -573,40 +573,40 @@ else: CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) -RAG_TOP_K = WrappedConfig( +RAG_TOP_K = PersistedConfig( "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5")) ) -RAG_RELEVANCE_THRESHOLD = WrappedConfig( +RAG_RELEVANCE_THRESHOLD = PersistedConfig( "RAG_RELEVANCE_THRESHOLD", "rag.relevance_threshold", float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")), ) -ENABLE_RAG_HYBRID_SEARCH = WrappedConfig( +ENABLE_RAG_HYBRID_SEARCH = PersistedConfig( "ENABLE_RAG_HYBRID_SEARCH", "rag.enable_hybrid_search", os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true", ) -ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = WrappedConfig( +ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistedConfig( "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "rag.enable_web_loader_ssl_verification", os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true", ) -RAG_EMBEDDING_ENGINE = WrappedConfig( +RAG_EMBEDDING_ENGINE = PersistedConfig( "RAG_EMBEDDING_ENGINE", "rag.embedding_engine", os.environ.get("RAG_EMBEDDING_ENGINE", ""), ) -PDF_EXTRACT_IMAGES = WrappedConfig( +PDF_EXTRACT_IMAGES = PersistedConfig( "PDF_EXTRACT_IMAGES", "rag.pdf_extract_images", os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true", ) -RAG_EMBEDDING_MODEL = WrappedConfig( +RAG_EMBEDDING_MODEL = PersistedConfig( "RAG_EMBEDDING_MODEL", "rag.embedding_model", os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"), @@ -621,7 +621,7 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) -RAG_RERANKING_MODEL = WrappedConfig( +RAG_RERANKING_MODEL = PersistedConfig( "RAG_RERANKING_MODEL", "rag.reranking_model", os.environ.get("RAG_RERANKING_MODEL", ""), @@ -665,10 +665,10 @@ if USE_CUDA.lower() == "true": else: DEVICE_TYPE = "cpu" -CHUNK_SIZE = WrappedConfig( +CHUNK_SIZE = PersistedConfig( "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500")) ) -CHUNK_OVERLAP = WrappedConfig( +CHUNK_OVERLAP = PersistedConfig( "CHUNK_OVERLAP", "rag.chunk_overlap", int(os.environ.get("CHUNK_OVERLAP", "100")), @@ -688,18 +688,18 @@ And answer according to the language of the user's question. Given the context information, answer the query. Query: [query]""" -RAG_TEMPLATE = WrappedConfig( +RAG_TEMPLATE = PersistedConfig( "RAG_TEMPLATE", "rag.template", os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE), ) -RAG_OPENAI_API_BASE_URL = WrappedConfig( +RAG_OPENAI_API_BASE_URL = PersistedConfig( "RAG_OPENAI_API_BASE_URL", "rag.openai_api_base_url", os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), ) -RAG_OPENAI_API_KEY = WrappedConfig( +RAG_OPENAI_API_KEY = PersistedConfig( "RAG_OPENAI_API_KEY", "rag.openai_api_key", os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY), @@ -709,7 +709,7 @@ ENABLE_RAG_LOCAL_WEB_FETCH = ( os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true" ) -YOUTUBE_LOADER_LANGUAGE = WrappedConfig( +YOUTUBE_LOADER_LANGUAGE = PersistedConfig( "YOUTUBE_LOADER_LANGUAGE", "rag.youtube_loader_language", os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","), @@ -730,49 +730,49 @@ WHISPER_MODEL_AUTO_UPDATE = ( # Images #################################### -IMAGE_GENERATION_ENGINE = WrappedConfig( +IMAGE_GENERATION_ENGINE = PersistedConfig( "IMAGE_GENERATION_ENGINE", "image_generation.engine", os.getenv("IMAGE_GENERATION_ENGINE", ""), ) -ENABLE_IMAGE_GENERATION = WrappedConfig( +ENABLE_IMAGE_GENERATION = PersistedConfig( "ENABLE_IMAGE_GENERATION", "image_generation.enable", os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true", ) -AUTOMATIC1111_BASE_URL = WrappedConfig( +AUTOMATIC1111_BASE_URL = PersistedConfig( "AUTOMATIC1111_BASE_URL", "image_generation.automatic1111.base_url", os.getenv("AUTOMATIC1111_BASE_URL", ""), ) -COMFYUI_BASE_URL = WrappedConfig( +COMFYUI_BASE_URL = PersistedConfig( "COMFYUI_BASE_URL", "image_generation.comfyui.base_url", os.getenv("COMFYUI_BASE_URL", ""), ) -IMAGES_OPENAI_API_BASE_URL = WrappedConfig( +IMAGES_OPENAI_API_BASE_URL = PersistedConfig( "IMAGES_OPENAI_API_BASE_URL", "image_generation.openai.api_base_url", os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), ) -IMAGES_OPENAI_API_KEY = WrappedConfig( +IMAGES_OPENAI_API_KEY = PersistedConfig( "IMAGES_OPENAI_API_KEY", "image_generation.openai.api_key", os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY), ) -IMAGE_SIZE = WrappedConfig( +IMAGE_SIZE = PersistedConfig( "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512") ) -IMAGE_STEPS = WrappedConfig( +IMAGE_STEPS = PersistedConfig( "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50)) ) -IMAGE_GENERATION_MODEL = WrappedConfig( +IMAGE_GENERATION_MODEL = PersistedConfig( "IMAGE_GENERATION_MODEL", "image_generation.model", os.getenv("IMAGE_GENERATION_MODEL", ""), @@ -782,22 +782,22 @@ IMAGE_GENERATION_MODEL = WrappedConfig( # Audio #################################### -AUDIO_OPENAI_API_BASE_URL = WrappedConfig( +AUDIO_OPENAI_API_BASE_URL = PersistedConfig( "AUDIO_OPENAI_API_BASE_URL", "audio.openai.api_base_url", os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), ) -AUDIO_OPENAI_API_KEY = WrappedConfig( +AUDIO_OPENAI_API_KEY = PersistedConfig( "AUDIO_OPENAI_API_KEY", "audio.openai.api_key", os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY), ) -AUDIO_OPENAI_API_MODEL = WrappedConfig( +AUDIO_OPENAI_API_MODEL = PersistedConfig( "AUDIO_OPENAI_API_MODEL", "audio.openai.api_model", os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"), ) -AUDIO_OPENAI_API_VOICE = WrappedConfig( +AUDIO_OPENAI_API_VOICE = PersistedConfig( "AUDIO_OPENAI_API_VOICE", "audio.openai.api_voice", os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"), From 0c033b5b7b394dd525c6183c8962c6e680277534 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 13 May 2024 11:32:21 -1000 Subject: [PATCH 6/6] refac: rename --- backend/config.py | 90 ++++++++++++++++++++++++----------------------- 1 file changed, 46 insertions(+), 44 deletions(-) diff --git a/backend/config.py b/backend/config.py index 106044540..112edba90 100644 --- a/backend/config.py +++ b/backend/config.py @@ -201,7 +201,7 @@ def get_config_value(config_path: str): T = TypeVar("T") -class PersistedConfig(Generic[T]): +class PersistentConfig(Generic[T]): def __init__(self, env_name: str, config_path: str, env_value: T): self.env_name = env_name self.config_path = config_path @@ -219,13 +219,13 @@ class PersistedConfig(Generic[T]): @property def __dict__(self): raise TypeError( - "PersistedConfig object cannot be converted to dict, use config_get or .value instead." + "PersistentConfig object cannot be converted to dict, use config_get or .value instead." ) def __getattribute__(self, item): if item == "__dict__": raise TypeError( - "PersistedConfig object cannot be converted to dict, use config_get or .value instead." + "PersistentConfig object cannot be converted to dict, use config_get or .value instead." ) return super().__getattribute__(item) @@ -247,13 +247,13 @@ class PersistedConfig(Generic[T]): class AppConfig: - _state: dict[str, PersistedConfig] + _state: dict[str, PersistentConfig] def __init__(self): super().__setattr__("_state", {}) def __setattr__(self, key, value): - if isinstance(value, PersistedConfig): + if isinstance(value, PersistentConfig): self._state[key] = value else: self._state[key].value = value @@ -271,7 +271,7 @@ WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None ) -JWT_EXPIRES_IN = PersistedConfig( +JWT_EXPIRES_IN = PersistentConfig( "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") ) @@ -409,7 +409,7 @@ OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "") OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")] -OLLAMA_BASE_URLS = PersistedConfig( +OLLAMA_BASE_URLS = PersistentConfig( "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS ) @@ -428,7 +428,9 @@ OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "") OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")] -OPENAI_API_KEYS = PersistedConfig("OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS) +OPENAI_API_KEYS = PersistentConfig( + "OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS +) OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "") OPENAI_API_BASE_URLS = ( @@ -439,7 +441,7 @@ OPENAI_API_BASE_URLS = [ url.strip() if url != "" else "https://api.openai.com/v1" for url in OPENAI_API_BASE_URLS.split(";") ] -OPENAI_API_BASE_URLS = PersistedConfig( +OPENAI_API_BASE_URLS = PersistentConfig( "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS ) @@ -458,7 +460,7 @@ OPENAI_API_BASE_URL = "https://api.openai.com/v1" # WEBUI #################################### -ENABLE_SIGNUP = PersistedConfig( +ENABLE_SIGNUP = PersistentConfig( "ENABLE_SIGNUP", "ui.enable_signup", ( @@ -467,11 +469,11 @@ ENABLE_SIGNUP = PersistedConfig( else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true" ), ) -DEFAULT_MODELS = PersistedConfig( +DEFAULT_MODELS = PersistentConfig( "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None) ) -DEFAULT_PROMPT_SUGGESTIONS = PersistedConfig( +DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig( "DEFAULT_PROMPT_SUGGESTIONS", "ui.prompt_suggestions", [ @@ -505,7 +507,7 @@ DEFAULT_PROMPT_SUGGESTIONS = PersistedConfig( ], ) -DEFAULT_USER_ROLE = PersistedConfig( +DEFAULT_USER_ROLE = PersistentConfig( "DEFAULT_USER_ROLE", "ui.default_user_role", os.getenv("DEFAULT_USER_ROLE", "pending"), @@ -515,25 +517,25 @@ USER_PERMISSIONS_CHAT_DELETION = ( os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true" ) -USER_PERMISSIONS = PersistedConfig( +USER_PERMISSIONS = PersistentConfig( "USER_PERMISSIONS", "ui.user_permissions", {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}, ) -ENABLE_MODEL_FILTER = PersistedConfig( +ENABLE_MODEL_FILTER = PersistentConfig( "ENABLE_MODEL_FILTER", "model_filter.enable", os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true", ) MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") -MODEL_FILTER_LIST = PersistedConfig( +MODEL_FILTER_LIST = PersistentConfig( "MODEL_FILTER_LIST", "model_filter.list", [model.strip() for model in MODEL_FILTER_LIST.split(";")], ) -WEBHOOK_URL = PersistedConfig( +WEBHOOK_URL = PersistentConfig( "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "") ) @@ -573,40 +575,40 @@ else: CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) -RAG_TOP_K = PersistedConfig( +RAG_TOP_K = PersistentConfig( "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5")) ) -RAG_RELEVANCE_THRESHOLD = PersistedConfig( +RAG_RELEVANCE_THRESHOLD = PersistentConfig( "RAG_RELEVANCE_THRESHOLD", "rag.relevance_threshold", float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")), ) -ENABLE_RAG_HYBRID_SEARCH = PersistedConfig( +ENABLE_RAG_HYBRID_SEARCH = PersistentConfig( "ENABLE_RAG_HYBRID_SEARCH", "rag.enable_hybrid_search", os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true", ) -ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistedConfig( +ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig( "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "rag.enable_web_loader_ssl_verification", os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true", ) -RAG_EMBEDDING_ENGINE = PersistedConfig( +RAG_EMBEDDING_ENGINE = PersistentConfig( "RAG_EMBEDDING_ENGINE", "rag.embedding_engine", os.environ.get("RAG_EMBEDDING_ENGINE", ""), ) -PDF_EXTRACT_IMAGES = PersistedConfig( +PDF_EXTRACT_IMAGES = PersistentConfig( "PDF_EXTRACT_IMAGES", "rag.pdf_extract_images", os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true", ) -RAG_EMBEDDING_MODEL = PersistedConfig( +RAG_EMBEDDING_MODEL = PersistentConfig( "RAG_EMBEDDING_MODEL", "rag.embedding_model", os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"), @@ -621,7 +623,7 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) -RAG_RERANKING_MODEL = PersistedConfig( +RAG_RERANKING_MODEL = PersistentConfig( "RAG_RERANKING_MODEL", "rag.reranking_model", os.environ.get("RAG_RERANKING_MODEL", ""), @@ -665,10 +667,10 @@ if USE_CUDA.lower() == "true": else: DEVICE_TYPE = "cpu" -CHUNK_SIZE = PersistedConfig( +CHUNK_SIZE = PersistentConfig( "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500")) ) -CHUNK_OVERLAP = PersistedConfig( +CHUNK_OVERLAP = PersistentConfig( "CHUNK_OVERLAP", "rag.chunk_overlap", int(os.environ.get("CHUNK_OVERLAP", "100")), @@ -688,18 +690,18 @@ And answer according to the language of the user's question. Given the context information, answer the query. Query: [query]""" -RAG_TEMPLATE = PersistedConfig( +RAG_TEMPLATE = PersistentConfig( "RAG_TEMPLATE", "rag.template", os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE), ) -RAG_OPENAI_API_BASE_URL = PersistedConfig( +RAG_OPENAI_API_BASE_URL = PersistentConfig( "RAG_OPENAI_API_BASE_URL", "rag.openai_api_base_url", os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), ) -RAG_OPENAI_API_KEY = PersistedConfig( +RAG_OPENAI_API_KEY = PersistentConfig( "RAG_OPENAI_API_KEY", "rag.openai_api_key", os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY), @@ -709,7 +711,7 @@ ENABLE_RAG_LOCAL_WEB_FETCH = ( os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true" ) -YOUTUBE_LOADER_LANGUAGE = PersistedConfig( +YOUTUBE_LOADER_LANGUAGE = PersistentConfig( "YOUTUBE_LOADER_LANGUAGE", "rag.youtube_loader_language", os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","), @@ -730,49 +732,49 @@ WHISPER_MODEL_AUTO_UPDATE = ( # Images #################################### -IMAGE_GENERATION_ENGINE = PersistedConfig( +IMAGE_GENERATION_ENGINE = PersistentConfig( "IMAGE_GENERATION_ENGINE", "image_generation.engine", os.getenv("IMAGE_GENERATION_ENGINE", ""), ) -ENABLE_IMAGE_GENERATION = PersistedConfig( +ENABLE_IMAGE_GENERATION = PersistentConfig( "ENABLE_IMAGE_GENERATION", "image_generation.enable", os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true", ) -AUTOMATIC1111_BASE_URL = PersistedConfig( +AUTOMATIC1111_BASE_URL = PersistentConfig( "AUTOMATIC1111_BASE_URL", "image_generation.automatic1111.base_url", os.getenv("AUTOMATIC1111_BASE_URL", ""), ) -COMFYUI_BASE_URL = PersistedConfig( +COMFYUI_BASE_URL = PersistentConfig( "COMFYUI_BASE_URL", "image_generation.comfyui.base_url", os.getenv("COMFYUI_BASE_URL", ""), ) -IMAGES_OPENAI_API_BASE_URL = PersistedConfig( +IMAGES_OPENAI_API_BASE_URL = PersistentConfig( "IMAGES_OPENAI_API_BASE_URL", "image_generation.openai.api_base_url", os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), ) -IMAGES_OPENAI_API_KEY = PersistedConfig( +IMAGES_OPENAI_API_KEY = PersistentConfig( "IMAGES_OPENAI_API_KEY", "image_generation.openai.api_key", os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY), ) -IMAGE_SIZE = PersistedConfig( +IMAGE_SIZE = PersistentConfig( "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512") ) -IMAGE_STEPS = PersistedConfig( +IMAGE_STEPS = PersistentConfig( "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50)) ) -IMAGE_GENERATION_MODEL = PersistedConfig( +IMAGE_GENERATION_MODEL = PersistentConfig( "IMAGE_GENERATION_MODEL", "image_generation.model", os.getenv("IMAGE_GENERATION_MODEL", ""), @@ -782,22 +784,22 @@ IMAGE_GENERATION_MODEL = PersistedConfig( # Audio #################################### -AUDIO_OPENAI_API_BASE_URL = PersistedConfig( +AUDIO_OPENAI_API_BASE_URL = PersistentConfig( "AUDIO_OPENAI_API_BASE_URL", "audio.openai.api_base_url", os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), ) -AUDIO_OPENAI_API_KEY = PersistedConfig( +AUDIO_OPENAI_API_KEY = PersistentConfig( "AUDIO_OPENAI_API_KEY", "audio.openai.api_key", os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY), ) -AUDIO_OPENAI_API_MODEL = PersistedConfig( +AUDIO_OPENAI_API_MODEL = PersistentConfig( "AUDIO_OPENAI_API_MODEL", "audio.openai.api_model", os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"), ) -AUDIO_OPENAI_API_VOICE = PersistedConfig( +AUDIO_OPENAI_API_VOICE = PersistentConfig( "AUDIO_OPENAI_API_VOICE", "audio.openai.api_voice", os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),