From 83e5db7be74879cdc97460cb7ef2763543e744d9 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 12 Feb 2025 23:26:47 -0800 Subject: [PATCH] refac --- backend/open_webui/routers/tasks.py | 14 +++++++------- backend/open_webui/utils/chat.py | 6 +++--- backend/open_webui/utils/middleware.py | 2 +- src/routes/+layout.svelte | 5 +++++ 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index c885d764b..3fcca0e07 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -139,7 +139,7 @@ async def update_task_config( async def generate_title( request: Request, form_data: dict, user=Depends(get_verified_user) ): - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -231,7 +231,7 @@ async def generate_chat_tags( content={"detail": "Tags generation is disabled"}, ) - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -293,7 +293,7 @@ async def generate_chat_tags( async def generate_image_prompt( request: Request, form_data: dict, user=Depends(get_verified_user) ): - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -374,7 +374,7 @@ async def generate_queries( detail=f"Query generation is disabled", ) - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -455,7 +455,7 @@ async def generate_autocompletion( detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", ) - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -518,7 +518,7 @@ async def generate_emoji( request: Request, form_data: dict, user=Depends(get_verified_user) ): - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -587,7 +587,7 @@ async def generate_moa_response( request: Request, form_data: dict, user=Depends(get_verified_user) ): - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 72f8eafe3..03b5d589c 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -165,7 +165,7 @@ async def generate_chat_completion( if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -284,7 +284,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any): if not request.app.state.MODELS: await get_all_models(request) - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -350,7 +350,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A if not request.app.state.MODELS: await get_all_models(request) - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index e630af687..719599001 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -622,7 +622,7 @@ async def process_chat_payload(request, form_data, metadata, user, model): # Initialize events to store additional event to be sent to the client # Initialize contexts and citation - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index bef581e74..cae8fba3a 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -273,6 +273,11 @@ const API_CONFIG = directConnections.OPENAI_API_CONFIGS[urlIdx]; try { + if (API_CONFIG?.prefix_id) { + const prefixId = API_CONFIG.prefix_id; + form_data['model'] = form_data['model'].replace(`${prefixId}.`, ``); + } + const [res, controller] = await chatCompletion( OPENAI_API_KEY, form_data,