From f6bec8d9f3c0c503c0c0d67ac5f12ca70edc1856 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 10 Dec 2024 00:00:01 -0800 Subject: [PATCH 01/26] general refac --- backend/open_webui/main.py | 581 +---------------------- backend/open_webui/routers/chat.py | 0 backend/open_webui/routers/pipelines.py | 99 ++++ backend/open_webui/routers/tasks.py | 596 ++++++++++++++++++++++++ backend/open_webui/utils/logo.png | Bin 6161 -> 0 bytes 5 files changed, 700 insertions(+), 576 deletions(-) create mode 100644 backend/open_webui/routers/chat.py create mode 100644 backend/open_webui/routers/pipelines.py create mode 100644 backend/open_webui/routers/tasks.py delete mode 100644 backend/open_webui/utils/logo.png diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 253a7a165..5ab820981 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -130,12 +130,6 @@ from open_webui.utils.response import ( from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.task import ( rag_template, - title_generation_template, - query_generation_template, - autocomplete_generation_template, - tags_generation_template, - emoji_generation_template, - moa_response_generation_template, tools_function_calling_generation_template, ) from open_webui.utils.tools import get_tools @@ -1263,12 +1257,15 @@ async def get_base_models(user=Depends(get_admin_user)): @app.post("/api/chat/completions") async def generate_chat_completions( - form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False + request: Request, + form_data: dict, + user=Depends(get_verified_user), + bypass_filter: bool = False, ): if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True - model_list = await get_all_models() + model_list = request.state.models models = {model["id"]: model for model in model_list} model_id = form_data["model"] @@ -1665,574 +1662,6 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified return data -################################## -# -# Task Endpoints -# -################################## - - -# TODO: Refactor task API endpoints below into a separate file - - -@app.get("/api/task/config") -async def get_task_config(user=Depends(get_verified_user)): - return { - "TASK_MODEL": app.state.config.TASK_MODEL, - "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, - "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, - "ENABLE_AUTOCOMPLETE_GENERATION": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, - "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, - "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, - "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, - "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, - "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, - "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, - "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - } - - -class TaskConfigForm(BaseModel): - TASK_MODEL: Optional[str] - TASK_MODEL_EXTERNAL: Optional[str] - TITLE_GENERATION_PROMPT_TEMPLATE: str - ENABLE_AUTOCOMPLETE_GENERATION: bool - AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int - TAGS_GENERATION_PROMPT_TEMPLATE: str - ENABLE_TAGS_GENERATION: bool - ENABLE_SEARCH_QUERY_GENERATION: bool - ENABLE_RETRIEVAL_QUERY_GENERATION: bool - QUERY_GENERATION_PROMPT_TEMPLATE: str - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str - - -@app.post("/api/task/config/update") -async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)): - app.state.config.TASK_MODEL = form_data.TASK_MODEL - app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL - app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( - form_data.TITLE_GENERATION_PROMPT_TEMPLATE - ) - - app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ( - form_data.ENABLE_AUTOCOMPLETE_GENERATION - ) - app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( - form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH - ) - - app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = ( - form_data.TAGS_GENERATION_PROMPT_TEMPLATE - ) - app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION - app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ( - form_data.ENABLE_SEARCH_QUERY_GENERATION - ) - app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ( - form_data.ENABLE_RETRIEVAL_QUERY_GENERATION - ) - - app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = ( - form_data.QUERY_GENERATION_PROMPT_TEMPLATE - ) - app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( - form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - ) - - return { - "TASK_MODEL": app.state.config.TASK_MODEL, - "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, - "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, - "ENABLE_AUTOCOMPLETE_GENERATION": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, - "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, - "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, - "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, - "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, - "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, - "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, - "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - } - - -@app.post("/api/task/title/completions") -async def generate_title(form_data: dict, user=Depends(get_verified_user)): - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id( - model_id, - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - - log.debug( - f"generating chat title using model {task_model_id} for user {user.email} " - ) - - if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": - template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE - else: - template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. - -Examples of titles: -📉 Stock Market Trends -🍪 Perfect Chocolate Chip Recipe -Evolution of Music Streaming -Remote Work Productivity Tips -Artificial Intelligence in Healthcare -🎮 Video Game Development Insights - - -{{MESSAGES:END:2}} -""" - - content = title_generation_template( - template, - form_data["messages"], - { - "name": user.name, - "location": user.info.get("location") if user.info else None, - }, - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - **( - {"max_tokens": 50} - if models[task_model_id]["owned_by"] == "ollama" - else { - "max_completion_tokens": 50, - } - ), - "metadata": { - "task": str(TASKS.TITLE_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), - }, - } - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user, models) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/tags/completions") -async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): - - if not app.state.config.ENABLE_TAGS_GENERATION: - return JSONResponse( - status_code=status.HTTP_200_OK, - content={"detail": "Tags generation is disabled"}, - ) - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id( - model_id, - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - - log.debug( - f"generating chat tags using model {task_model_id} for user {user.email} " - ) - - if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": - template = app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE - else: - template = """### Task: -Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags. - -### Guidelines: -- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education) -- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation -- If content is too short (less than 3 messages) or too diverse, use only ["General"] -- Use the chat's primary language; default to English if multilingual -- Prioritize accuracy over specificity - -### Output: -JSON format: { "tags": ["tag1", "tag2", "tag3"] } - -### Chat History: - -{{MESSAGES:END:6}} -""" - - content = tags_generation_template( - template, form_data["messages"], {"name": user.name} - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - "metadata": { - "task": str(TASKS.TAGS_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), - }, - } - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user, models) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/queries/completions") -async def generate_queries(form_data: dict, user=Depends(get_verified_user)): - - type = form_data.get("type") - if type == "web_search": - if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Search query generation is disabled", - ) - elif type == "retrieval": - if not app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Query generation is disabled", - ) - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id( - model_id, - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - - log.debug( - f"generating {type} queries using model {task_model_id} for user {user.email}" - ) - - if (app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "": - template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE - else: - template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE - - content = query_generation_template( - template, form_data["messages"], {"name": user.name} - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - "metadata": { - "task": str(TASKS.QUERY_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), - }, - } - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user, models) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/auto/completions") -async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)): - if not app.state.config.ENABLE_AUTOCOMPLETE_GENERATION: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Autocompletion generation is disabled", - ) - - type = form_data.get("type") - prompt = form_data.get("prompt") - messages = form_data.get("messages") - - if app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0: - if len(prompt) > app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Input prompt exceeds maximum length of {app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", - ) - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id( - model_id, - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - - log.debug( - f"generating autocompletion using model {task_model_id} for user {user.email}" - ) - - if (app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "": - template = app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE - else: - template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE - - content = autocomplete_generation_template( - template, prompt, messages, type, {"name": user.name} - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - "metadata": { - "task": str(TASKS.AUTOCOMPLETE_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), - }, - } - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user, models) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/emoji/completions") -async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id( - model_id, - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - - log.debug(f"generating emoji using model {task_model_id} for user {user.email} ") - - template = ''' -Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). - -Message: """{{prompt}}""" -''' - content = emoji_generation_template( - template, - form_data["prompt"], - { - "name": user.name, - "location": user.info.get("location") if user.info else None, - }, - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - **( - {"max_tokens": 4} - if models[task_model_id]["owned_by"] == "ollama" - else { - "max_completion_tokens": 4, - } - ), - "chat_id": form_data.get("chat_id", None), - "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data}, - } - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user, models) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/moa/completions") -async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)): - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id( - model_id, - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - - log.debug(f"generating MOA model {task_model_id} for user {user.email} ") - - template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" - -Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. - -Responses from models: {{responses}}""" - - content = moa_response_generation_template( - template, - form_data["prompt"], - form_data["responses"], - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": form_data.get("stream", False), - "chat_id": form_data.get("chat_id", None), - "metadata": { - "task": str(TASKS.MOA_RESPONSE_GENERATION), - "task_body": form_data, - }, - } - - try: - payload = filter_pipeline(payload, user, models) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - ################################## # # Pipelines Endpoints diff --git a/backend/open_webui/routers/chat.py b/backend/open_webui/routers/chat.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py new file mode 100644 index 000000000..0d9a32c83 --- /dev/null +++ b/backend/open_webui/routers/pipelines.py @@ -0,0 +1,99 @@ +from fastapi import APIRouter, Depends, HTTPException, Response, status +from pydantic import BaseModel +from starlette.responses import FileResponse + + +from open_webui.apps.webui.models.chats import ChatTitleMessagesForm +from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT +from open_webui.constants import ERROR_MESSAGES + +from open_webui.utils.misc import get_gravatar_url +from open_webui.utils.pdf_generator import PDFGenerator +from open_webui.utils.auth import get_admin_user + +router = APIRouter() + + +@router.get("/gravatar") +async def get_gravatar( + email: str, +): + return get_gravatar_url(email) + + +class CodeFormatRequest(BaseModel): + code: str + + +@router.post("/code/format") +async def format_code(request: CodeFormatRequest): + try: + formatted_code = black.format_str(request.code, mode=black.Mode()) + return {"code": formatted_code} + except black.NothingChanged: + return {"code": request.code} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +class MarkdownForm(BaseModel): + md: str + + +@router.post("/markdown") +async def get_html_from_markdown( + form_data: MarkdownForm, +): + return {"html": markdown.markdown(form_data.md)} + + +class ChatForm(BaseModel): + title: str + messages: list[dict] + + +@router.post("/pdf") +async def download_chat_as_pdf( + form_data: ChatTitleMessagesForm, +): + try: + pdf_bytes = PDFGenerator(form_data).generate_chat_pdf() + + return Response( + content=pdf_bytes, + media_type="application/pdf", + headers={"Content-Disposition": "attachment;filename=chat.pdf"}, + ) + except Exception as e: + print(e) + raise HTTPException(status_code=400, detail=str(e)) + + +@router.get("/db/download") +async def download_db(user=Depends(get_admin_user)): + if not ENABLE_ADMIN_EXPORT: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + from open_webui.apps.webui.internal.db import engine + + if engine.name != "sqlite": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DB_NOT_SQLITE, + ) + return FileResponse( + engine.url.database, + media_type="application/octet-stream", + filename="webui.db", + ) + + +@router.get("/litellm/config") +async def download_litellm_config_yaml(user=Depends(get_admin_user)): + return FileResponse( + f"{DATA_DIR}/litellm/config.yaml", + media_type="application/octet-stream", + filename="config.yaml", + ) diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py new file mode 100644 index 000000000..c4ccc4700 --- /dev/null +++ b/backend/open_webui/routers/tasks.py @@ -0,0 +1,596 @@ +from fastapi import APIRouter, Depends, HTTPException, Response, status, Request +from pydantic import BaseModel +from starlette.responses import FileResponse +from typing import Optional + +from open_webui.utils.task import ( + title_generation_template, + query_generation_template, + autocomplete_generation_template, + tags_generation_template, + emoji_generation_template, + moa_response_generation_template, +) +from open_webui.utils.auth import get_admin_user, get_verified_user + +router = APIRouter() + + +################################## +# +# Task Endpoints +# +################################## + + +@router.get("/config") +async def get_task_config(request: Request, user=Depends(get_verified_user)): + return { + "TASK_MODEL": request.app.state.config.TASK_MODEL, + "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, + "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, + "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, + "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, + "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + } + + +class TaskConfigForm(BaseModel): + TASK_MODEL: Optional[str] + TASK_MODEL_EXTERNAL: Optional[str] + TITLE_GENERATION_PROMPT_TEMPLATE: str + ENABLE_AUTOCOMPLETE_GENERATION: bool + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int + TAGS_GENERATION_PROMPT_TEMPLATE: str + ENABLE_TAGS_GENERATION: bool + ENABLE_SEARCH_QUERY_GENERATION: bool + ENABLE_RETRIEVAL_QUERY_GENERATION: bool + QUERY_GENERATION_PROMPT_TEMPLATE: str + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str + + +@router.post("/config/update") +async def update_task_config( + request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.TASK_MODEL = form_data.TASK_MODEL + request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL + request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( + form_data.TITLE_GENERATION_PROMPT_TEMPLATE + ) + + request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ( + form_data.ENABLE_AUTOCOMPLETE_GENERATION + ) + request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( + form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH + ) + + request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = ( + form_data.TAGS_GENERATION_PROMPT_TEMPLATE + ) + request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION + request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ( + form_data.ENABLE_SEARCH_QUERY_GENERATION + ) + request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ( + form_data.ENABLE_RETRIEVAL_QUERY_GENERATION + ) + + request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = ( + form_data.QUERY_GENERATION_PROMPT_TEMPLATE + ) + request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + ) + + return { + "TASK_MODEL": request.app.state.config.TASK_MODEL, + "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, + "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, + "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, + "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, + "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + } + + +@router.post("/title/completions") +async def generate_title( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating chat title using model {task_model_id} for user {user.email} " + ) + + if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": + template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE + else: + template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. + +Examples of titles: +📉 Stock Market Trends +🍪 Perfect Chocolate Chip Recipe +Evolution of Music Streaming +Remote Work Productivity Tips +Artificial Intelligence in Healthcare +🎮 Video Game Development Insights + + +{{MESSAGES:END:2}} +""" + + content = title_generation_template( + template, + form_data["messages"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + **( + {"max_tokens": 50} + if models[task_model_id]["owned_by"] == "ollama" + else { + "max_completion_tokens": 50, + } + ), + "metadata": { + "task": str(TASKS.TITLE_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + # Handle pipeline filters + try: + payload = filter_pipeline(payload, user, models) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + +@router.post("/tags/completions") +async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): + + if not request.app.state.config.ENABLE_TAGS_GENERATION: + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"detail": "Tags generation is disabled"}, + ) + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating chat tags using model {task_model_id} for user {user.email} " + ) + + if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": + template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE + else: + template = """### Task: +Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags. + +### Guidelines: +- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education) +- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation +- If content is too short (less than 3 messages) or too diverse, use only ["General"] +- Use the chat's primary language; default to English if multilingual +- Prioritize accuracy over specificity + +### Output: +JSON format: { "tags": ["tag1", "tag2", "tag3"] } + +### Chat History: + +{{MESSAGES:END:6}} +""" + + content = tags_generation_template( + template, form_data["messages"], {"name": user.name} + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + "task": str(TASKS.TAGS_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + # Handle pipeline filters + try: + payload = filter_pipeline(payload, user, models) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + +@router.post("/queries/completions") +async def generate_queries( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + type = form_data.get("type") + if type == "web_search": + if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Search query generation is disabled", + ) + elif type == "retrieval": + if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Query generation is disabled", + ) + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating {type} queries using model {task_model_id} for user {user.email}" + ) + + if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "": + template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE + + content = query_generation_template( + template, form_data["messages"], {"name": user.name} + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + "task": str(TASKS.QUERY_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + # Handle pipeline filters + try: + payload = filter_pipeline(payload, user, models) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + +@router.post("/auto/completions") +async def generate_autocompletion( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Autocompletion generation is disabled", + ) + + type = form_data.get("type") + prompt = form_data.get("prompt") + messages = form_data.get("messages") + + if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0: + if ( + len(prompt) + > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", + ) + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating autocompletion using model {task_model_id} for user {user.email}" + ) + + if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "": + template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE + + content = autocomplete_generation_template( + template, prompt, messages, type, {"name": user.name} + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + "task": str(TASKS.AUTOCOMPLETE_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + # Handle pipeline filters + try: + payload = filter_pipeline(payload, user, models) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + +@router.post("/emoji/completions") +async def generate_emoji( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug(f"generating emoji using model {task_model_id} for user {user.email} ") + + template = ''' +Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). + +Message: """{{prompt}}""" +''' + content = emoji_generation_template( + template, + form_data["prompt"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + **( + {"max_tokens": 4} + if models[task_model_id]["owned_by"] == "ollama" + else { + "max_completion_tokens": 4, + } + ), + "chat_id": form_data.get("chat_id", None), + "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data}, + } + + # Handle pipeline filters + try: + payload = filter_pipeline(payload, user, models) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + +@router.post("/moa/completions") +async def generate_moa_response( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug(f"generating MOA model {task_model_id} for user {user.email} ") + + template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" + +Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. + +Responses from models: {{responses}}""" + + content = moa_response_generation_template( + template, + form_data["prompt"], + form_data["responses"], + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": form_data.get("stream", False), + "chat_id": form_data.get("chat_id", None), + "metadata": { + "task": str(TASKS.MOA_RESPONSE_GENERATION), + "task_body": form_data, + }, + } + + try: + payload = filter_pipeline(payload, user, models) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) diff --git a/backend/open_webui/utils/logo.png b/backend/open_webui/utils/logo.png deleted file mode 100644 index 519af1db620dbf4de3694660dae7abd7392f0b3c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6161 zcmeHLc~p|yyGEmwIyrGYmN}H;sZ7c`4u~3N4yjnFDP@|b$1>*`2Q01QPt8nCjTEFT zNh?jw6cvZG98kn5D#x5aQBc4EM7`L$cinr}UH6as|M_FDZ|(Pe_q+FgzWsjBv%i#! z2&?^j754%Ffc-Ys7MB44G2Tv-+#|A(zmBJh9-v@rWEcRj54_XF0M7~zh>T)km#xeJ zWfY}3Q3vP?w}%4&6&d@4p1T16g<=~ExKp&);^<ANhuv)Xl$tys&ZRsJ`1`r9ys$S(fVdZeKcl;htZ>{_w)oWN@&7#z6T_fA9GXXheZuzUT6kO> z3|mmpuubP@Czx(-Z)lFkuZ_-v$3@PKa&7{e`Ssn}c0*Ow>6UD5hhhn76}y~<_>IUH zQgFhnwsk(2KJdn+Tt5NPOuyxkJg|M;_d!DL>)sFt2783a-7?!%E0u^q5vH)^+(D9- z(er-^*~S4X*^_D1$<`V!#)5hQIS9 zcYS%vK3a~O5i~XFr}xU&uPJ_G^)Ekm!#$bx$M`CFZe@%)m(izAZZ*(3(mSZTa8e3R zx5P>jeO7!VjUU3PL|OAj(@#<)5R|4%>v#}Hy007~7hYzsU)!_O-IQDPx9h&J_hKPF z%m>4f`#a7qu6mlDlk=qDU&PVOZRCSggd~*ibxd^t z_HKK9Dg(XxJ~WYnnN?mW=2XmJd3=)&mm63h7Sh!+EEV~e5t00o)fnNb4X0OCw&9X$ zlo#T3qV#Kp;!Jf$++nerHVUnqU8xY^Dh;QTP%c%Pk(5^5XjcWe@6ATdp0>~DyXz;e zyClzmp!uNVWI|b97?N0qi&tn41;z(s-s*X_IQxu0NPqHP39jnuh3ni~MviN$vo&}h zpJ@8J!P&N(2x@O{e<8I`h9{;ui+c0>$P+#&>4ns=IwMIo?Bn@C&w60|VK2(Z6sUg& zURtf@-%^ZhWg@;w!Bx*jR>eL!Zu`t*VKP_0)+lkT;M#y|8fwir_%x7I0TV-XBeN|` zVvX)HwHfv^-xqLzc10gZvTsxC3YhozEHmf*6{HbIG}JwDth=G;+pPLm-Q(8CM2WIw zyg@cIJ5c&X>RSzd`~&6@M*4J}cNzAaysmW1iY!m7vajP&G`lyoDp?Kp6=#H6>Z4F7;laTt#q!h{ z`=j5^@-_NXEKQ?WBj4k&>S(OoiV)^bim4_7vl2eqEHf)@0;8hVsit->H2<4ed*Xrc?j&mmx-Ce~))>xw{ z2|bR{xq+f`RM@vp)0lTCG#!nruH!tnn{V1&9Ue@_hYbKpukE5!hzA<7L2*GrL3Ih+ zTYl`+!{;&tf3hk@>kod)lBoPioRCXz0<+yM0aa*(;p}Nkci#@ zyP^leo1oZ6X`_mRHo6dGOfpxNXKopz=vLf^s>-+bVo%@K($>d(t6v0|*=|wvJ6dR< z7+KK1>NJcF50tRuM5c#Fi)Nd#YSaN`e)feV}`WMqspcZRF_g_xM};B6;6H*`5=FD z?m&biG!BO^cWR$=>#l)ZV%P2(dm>dR$IF@5$ENW^D|eZe=PtJWGReVW zT8CTrXTbH!4Uhgm8SFKLHu4>OzAMo? zLw0f;CO;U;8$MA@`|3G2T4`J)Lh*H@E$wn_doPf{n&~6 z)1p~iS4J15RZTw(j#;l{pC+>&+EEp*@BvRA$11u(Fx2A?b1HhR9l&_coIFV`6fuA@ z%OUXIMxu5dDw5;vo@sVm0l9#P>A@z>h5e$8$KW21=L)JzVXuY{w7&+%!*jGj1E`=J zLO9&lDX|e>N?uGM&NVcZI>rQga4Ube)L`r@X+X8d({Fa%cVC$cfw(t3y>TY7@v%#} zyS@C(;UK80)zGk#J@RpIgq)I-1Z)FYVTYc*Qm!Av~lplUw8RR9u*m+rxaEV5XLp=g_ZMY6ozPB5wy3?CMM%+y<} ze&jqhQ#8Pi;pF{Y`Kvv(=IvzL;aQ8%<4M`s*&xd!jB8g;F^r9k!6U>l-_kEx#_F=X z6K@Pp)uItjunKUvrJ5*W888+3r&|7G>87qE^jE-o+-+M{_kZ;x@1DYjT^(90YLx`W zlY!VSJH~r=5xW7g16$m}ZM&i9^kYRbqRf%Ovf+p~>l=p|56!7H-7OEg3MN`&Kv3UM z00AmO>j6FpS)FA2`A+eOl6Lg393#cRi%IxhXLlt^0rpEm0cQ>~0MxtquwsLtF+0H9 zLVOVB`jpdbi?t}a(Pt=;1)$uPPf#teIK3B67fFoU@`ZI;zovWpqvi9>(?*Na)83_) z)SbC7no-A2KW~iz338ode_vf*UIteA!4<-Gw*6NpwhO5I)piFnG0Q=W^Tze1NW&dZ zq)4@uD3X5x@RWKmMYQJ6DL+4_DS%-KTXCgd?o+HqC>&F{p+f|hyOW-k%Kj&gjv;ze z0%xH~Z5W*b6ali^DJ_O|-r$vE2;{*X#zPCL!h4YxbPfv-KXaGRHRq$*qt3AOGX%`nvkLq(84@bl9XP{C79 zSa~h!Lib4L!Q;~~>yq*%v4G5!9J4%wn-B`F1c%z=)2 z`cN~97R;p z$W!6yljif4MdfNob-@dsY>^>N{<(f7i|wij1;@xuR?Y^PPZ6T;kfEXB)6yz<*+-pf z7qg>PBNPTOes?LP-YU?uis$ZVj&siK;;))<=NT{j<2YzO8g<)w;>eeF~S_77>q4H#d%_>z{QkqphJ; zs}|D}ymSQa@jO{t^SgXlB!E*k5lvLgzBLR+fK>nX=_xs~}KPU*2z?2lmB z0C9A2XHdzb0V^#FGJAOF)T}2o*+H~9AMv?$|9St8N@1fiQg-e@*nrUv|GCnRYLMAc z?6K;N%{>vrfT(A4N;`Dv@9Rrg)W{lTr_^l>mb)ObD#iQ%oGlJp*MoZyDFfDCxa8hx z%%t0>aI<@_wab4v>@=_}+wE^kM~o-RwY1M8mpi(wL_WiOyu#y>mmBrCDvO0n6?xgl z5*!kPpr@Nd~K*T%r9lu>Mk9BC;RG)ocHRPbF<9b5_TJ(O?wN4peTltbP zr_>#Ekbn1nz`c(uyi9*gXB+o(1h&}d~g8$|_H|5xGBpa)6{-B_Ylv_9JGHUkO|dpSbx*sI+Fs;X@NBd&x@l zxE}iS^z>#SIC#0Swr!*=vQ~7O?WeB_Bzt+qgaQ}|y%{s#on^cua<$MKi#5-2YF4V; zh_DGiASu($A!;+zPdN5zMTTcN%bY73FGMI40|R|Pzpf4Y-I3_zYVz@WT-R$8F_JOo zz8_uK@B9{tfCWt6XKFI+eK;r{qYrrJw1=X2SsvctFS_xa0lrP}`!nz$-Rz6G2A;RS zzdbq|90nv2l`Oy!5|?r{W`jlQ5s+(y>{kjqem=Mg7%%gJa=9|^T=&mekl?HZHlWKJ zM`%X<-0Po=m$n@I{x=G9z1uSu7=QK!#Xhu0Dfb8?=|Zh;`7IE%#`Od_P+?%bXdTmSCze;x6Sn za$iDdmY5zI`Q1dbKoaWPmEEY+jLG|U9g@5v7vAXSo)~p!_-W;bz!c@L4pgI>ksrD5 z^S`z~E`jA;_s_bo2*mC?^)AB`*A?*aJ`{+xJ@szwNr~`#m2YwAAJ$aj2h`PGM6|@V zpfbtCtao#B6ZG1ON~~PCzs6$LRW_KiZ$zc34Qevn((D+R!P7rMvHQ8heVmFU<#}*; z*^p8s?FQ}67zvjRIg)G(7LSligB3qr9FORgRa8!GV<5?$CDsOa`(?lGMOq;)>m zhWj3$O%K!W4RAo&%l{aNiJaXhX@Sm)tTx))%R~9+{Z)l&~7TCFJvXa|1oJrY*spNh~nd*Vh+n zX!s4*XWx7egRi=P)WhYagI9`*ih5{MQ&YFkq(n}VYHDiuxPKWg86{HYv5#VlR|f8F ziwKbrID0fsbaPBsa$1#=%f}Ceceb zN-b3UWW%p^w&7mWtG;yj7M@MSmWMS5h$|)STo18FG#0q@o_quLb^-Vs_6*oW=M9_h|Bv3-FSgyFU+y_*M3wVTvcc&+w&&qd-&MQ7K4S6BZ_ nZvA8UpYrn`JpNBSmbPOCTRwQJXvm2E-~(*VBP`0yZzTRZFN9iO From d3d161f723e8667e236510c198cf6b194c04e118 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 10 Dec 2024 00:54:13 -0800 Subject: [PATCH 02/26] wip --- backend/open_webui/config.py | 7 +- .../{apps/webui => }/internal/db.py | 2 +- .../internal/migrations/001_initial_schema.py | 0 .../migrations/002_add_local_sharing.py | 0 .../migrations/003_add_auth_api_key.py | 0 .../internal/migrations/004_add_archived.py | 0 .../internal/migrations/005_add_updated_at.py | 0 .../006_migrate_timestamps_and_charfields.py | 0 .../migrations/007_add_user_last_active_at.py | 0 .../internal/migrations/008_add_memory.py | 0 .../internal/migrations/009_add_models.py | 0 .../010_migrate_modelfiles_to_models.py | 0 .../migrations/011_add_user_settings.py | 0 .../internal/migrations/012_add_tools.py | 0 .../internal/migrations/013_add_user_info.py | 0 .../internal/migrations/014_add_files.py | 0 .../internal/migrations/015_add_functions.py | 0 .../016_add_valves_and_is_active.py | 0 .../migrations/017_add_user_oauth_sub.py | 0 .../migrations/018_add_function_is_global.py | 0 .../{apps/webui => }/internal/wrappers.py | 0 backend/open_webui/main.py | 1199 +++++------------ backend/open_webui/migrations/env.py | 2 +- backend/open_webui/migrations/script.py.mako | 2 +- .../migrations/versions/7e5b5dc7342b_init.py | 4 +- .../{apps/webui => }/models/auths.py | 4 +- .../{apps/webui => }/models/chats.py | 4 +- .../{apps/webui => }/models/feedbacks.py | 4 +- .../{apps/webui => }/models/files.py | 2 +- .../{apps/webui => }/models/folders.py | 4 +- .../{apps/webui => }/models/functions.py | 4 +- .../{apps/webui => }/models/groups.py | 4 +- .../{apps/webui => }/models/knowledge.py | 6 +- .../{apps/webui => }/models/memories.py | 2 +- .../{apps/webui => }/models/models.py | 4 +- .../{apps/webui => }/models/prompts.py | 4 +- .../{apps/webui => }/models/tags.py | 2 +- .../{apps/webui => }/models/tools.py | 4 +- .../{apps/webui => }/models/users.py | 4 +- .../{apps => }/retrieval/loaders/main.py | 0 .../{apps => }/retrieval/loaders/youtube.py | 0 .../{apps => }/retrieval/models/colbert.py | 0 .../open_webui/{apps => }/retrieval/utils.py | 0 .../{apps => }/retrieval/vector/connector.py | 0 .../{apps => }/retrieval/vector/dbs/chroma.py | 0 .../{apps => }/retrieval/vector/dbs/milvus.py | 0 .../retrieval/vector/dbs/opensearch.py | 0 .../retrieval/vector/dbs/pgvector.py | 2 +- .../{apps => }/retrieval/vector/dbs/qdrant.py | 0 .../{apps => }/retrieval/vector/main.py | 0 .../{apps => }/retrieval/web/bing.py | 0 .../{apps => }/retrieval/web/brave.py | 0 .../{apps => }/retrieval/web/duckduckgo.py | 0 .../{apps => }/retrieval/web/google_pse.py | 0 .../{apps => }/retrieval/web/jina_search.py | 0 .../{apps => }/retrieval/web/kagi.py | 0 .../{apps => }/retrieval/web/main.py | 0 .../{apps => }/retrieval/web/mojeek.py | 0 .../{apps => }/retrieval/web/searchapi.py | 0 .../{apps => }/retrieval/web/searxng.py | 0 .../{apps => }/retrieval/web/serper.py | 0 .../{apps => }/retrieval/web/serply.py | 0 .../{apps => }/retrieval/web/serpstack.py | 0 .../{apps => }/retrieval/web/tavily.py | 0 .../retrieval/web/testdata/bing.json | 0 .../retrieval/web/testdata/brave.json | 0 .../retrieval/web/testdata/google_pse.json | 0 .../retrieval/web/testdata/searchapi.json | 0 .../retrieval/web/testdata/searxng.json | 0 .../retrieval/web/testdata/serper.json | 0 .../retrieval/web/testdata/serply.json | 0 .../retrieval/web/testdata/serpstack.json | 0 .../{apps => }/retrieval/web/utils.py | 0 .../{apps/audio/main.py => routers/audio.py} | 41 +- .../{apps/webui => }/routers/auths.py | 4 +- backend/open_webui/routers/chat.py | 411 ++++++ .../{apps/webui => }/routers/chats.py | 6 +- .../{apps/webui => }/routers/configs.py | 0 .../{apps/webui => }/routers/evaluations.py | 4 +- .../{apps/webui => }/routers/files.py | 4 +- .../{apps/webui => }/routers/folders.py | 4 +- .../{apps/webui => }/routers/functions.py | 4 +- .../{apps/webui => }/routers/groups.py | 2 +- .../images/main.py => routers/images.py} | 55 +- .../{apps/webui => }/routers/knowledge.py | 6 +- .../{apps/webui => }/routers/memories.py | 2 +- .../{apps/webui => }/routers/models.py | 2 +- .../ollama/main.py => routers/ollama.py} | 45 +- .../openai/main.py => routers/openai.py} | 26 +- backend/open_webui/routers/pipelines.py | 398 ++++-- .../{apps/webui => }/routers/prompts.py | 2 +- .../main.py => routers/retrieval.py} | 6 +- backend/open_webui/routers/tasks.py | 17 +- .../{apps/webui => }/routers/tools.py | 4 +- .../{apps/webui => }/routers/users.py | 6 +- .../{apps/webui => }/routers/utils.py | 4 +- .../{apps/webui/main.py => routers/webui.py} | 8 +- backend/open_webui/{apps => }/socket/main.py | 2 +- backend/open_webui/{apps => }/socket/utils.py | 0 .../test/apps/webui/routers/test_auths.py | 4 +- .../test/apps/webui/routers/test_chats.py | 4 +- .../test/apps/webui/routers/test_models.py | 2 +- .../test/apps/webui/routers/test_users.py | 2 +- .../test/util/abstract_integration_test.py | 4 +- backend/open_webui/test/util/mock_user.py | 4 +- backend/open_webui/utils/access_control.py | 2 +- backend/open_webui/utils/auth.py | 2 +- .../images/utils => utils/images}/comfyui.py | 0 backend/open_webui/utils/oauth.py | 13 +- backend/open_webui/utils/pdf_generator.py | 2 +- .../{apps/webui/utils.py => utils/plugin.py} | 4 +- backend/open_webui/utils/tools.py | 6 +- 112 files changed, 1217 insertions(+), 1165 deletions(-) rename backend/open_webui/{apps/webui => }/internal/db.py (97%) rename backend/open_webui/{apps/webui => }/internal/migrations/001_initial_schema.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/002_add_local_sharing.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/003_add_auth_api_key.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/004_add_archived.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/005_add_updated_at.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/006_migrate_timestamps_and_charfields.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/007_add_user_last_active_at.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/008_add_memory.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/009_add_models.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/010_migrate_modelfiles_to_models.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/011_add_user_settings.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/012_add_tools.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/013_add_user_info.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/014_add_files.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/015_add_functions.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/016_add_valves_and_is_active.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/017_add_user_oauth_sub.py (100%) rename backend/open_webui/{apps/webui => }/internal/migrations/018_add_function_is_global.py (100%) rename backend/open_webui/{apps/webui => }/internal/wrappers.py (100%) rename backend/open_webui/{apps/webui => }/models/auths.py (97%) rename backend/open_webui/{apps/webui => }/models/chats.py (99%) rename backend/open_webui/{apps/webui => }/models/feedbacks.py (98%) rename backend/open_webui/{apps/webui => }/models/files.py (98%) rename backend/open_webui/{apps/webui => }/models/folders.py (98%) rename backend/open_webui/{apps/webui => }/models/functions.py (98%) rename backend/open_webui/{apps/webui => }/models/groups.py (97%) rename backend/open_webui/{apps/webui => }/models/knowledge.py (97%) rename backend/open_webui/{apps/webui => }/models/memories.py (98%) rename backend/open_webui/{apps/webui => }/models/models.py (98%) rename backend/open_webui/{apps/webui => }/models/prompts.py (97%) rename backend/open_webui/{apps/webui => }/models/tags.py (98%) rename backend/open_webui/{apps/webui => }/models/tools.py (98%) rename backend/open_webui/{apps/webui => }/models/users.py (98%) rename backend/open_webui/{apps => }/retrieval/loaders/main.py (100%) rename backend/open_webui/{apps => }/retrieval/loaders/youtube.py (100%) rename backend/open_webui/{apps => }/retrieval/models/colbert.py (100%) rename backend/open_webui/{apps => }/retrieval/utils.py (100%) rename backend/open_webui/{apps => }/retrieval/vector/connector.py (100%) rename backend/open_webui/{apps => }/retrieval/vector/dbs/chroma.py (100%) rename backend/open_webui/{apps => }/retrieval/vector/dbs/milvus.py (100%) rename backend/open_webui/{apps => }/retrieval/vector/dbs/opensearch.py (100%) rename backend/open_webui/{apps => }/retrieval/vector/dbs/pgvector.py (99%) rename backend/open_webui/{apps => }/retrieval/vector/dbs/qdrant.py (100%) rename backend/open_webui/{apps => }/retrieval/vector/main.py (100%) rename backend/open_webui/{apps => }/retrieval/web/bing.py (100%) rename backend/open_webui/{apps => }/retrieval/web/brave.py (100%) rename backend/open_webui/{apps => }/retrieval/web/duckduckgo.py (100%) rename backend/open_webui/{apps => }/retrieval/web/google_pse.py (100%) rename backend/open_webui/{apps => }/retrieval/web/jina_search.py (100%) rename backend/open_webui/{apps => }/retrieval/web/kagi.py (100%) rename backend/open_webui/{apps => }/retrieval/web/main.py (100%) rename backend/open_webui/{apps => }/retrieval/web/mojeek.py (100%) rename backend/open_webui/{apps => }/retrieval/web/searchapi.py (100%) rename backend/open_webui/{apps => }/retrieval/web/searxng.py (100%) rename backend/open_webui/{apps => }/retrieval/web/serper.py (100%) rename backend/open_webui/{apps => }/retrieval/web/serply.py (100%) rename backend/open_webui/{apps => }/retrieval/web/serpstack.py (100%) rename backend/open_webui/{apps => }/retrieval/web/tavily.py (100%) rename backend/open_webui/{apps => }/retrieval/web/testdata/bing.json (100%) rename backend/open_webui/{apps => }/retrieval/web/testdata/brave.json (100%) rename backend/open_webui/{apps => }/retrieval/web/testdata/google_pse.json (100%) rename backend/open_webui/{apps => }/retrieval/web/testdata/searchapi.json (100%) rename backend/open_webui/{apps => }/retrieval/web/testdata/searxng.json (100%) rename backend/open_webui/{apps => }/retrieval/web/testdata/serper.json (100%) rename backend/open_webui/{apps => }/retrieval/web/testdata/serply.json (100%) rename backend/open_webui/{apps => }/retrieval/web/testdata/serpstack.json (100%) rename backend/open_webui/{apps => }/retrieval/web/utils.py (100%) rename backend/open_webui/{apps/audio/main.py => routers/audio.py} (94%) rename backend/open_webui/{apps/webui => }/routers/auths.py (99%) rename backend/open_webui/{apps/webui => }/routers/chats.py (99%) rename backend/open_webui/{apps/webui => }/routers/configs.py (100%) rename backend/open_webui/{apps/webui => }/routers/evaluations.py (97%) rename backend/open_webui/{apps/webui => }/routers/files.py (98%) rename backend/open_webui/{apps/webui => }/routers/folders.py (98%) rename backend/open_webui/{apps/webui => }/routers/functions.py (98%) rename backend/open_webui/{apps/webui => }/routers/groups.py (98%) rename backend/open_webui/{apps/images/main.py => routers/images.py} (92%) rename backend/open_webui/{apps/webui => }/routers/knowledge.py (98%) rename backend/open_webui/{apps/webui => }/routers/memories.py (98%) rename backend/open_webui/{apps/webui => }/routers/models.py (99%) rename backend/open_webui/{apps/ollama/main.py => routers/ollama.py} (98%) rename backend/open_webui/{apps/openai/main.py => routers/openai.py} (97%) rename backend/open_webui/{apps/webui => }/routers/prompts.py (98%) rename backend/open_webui/{apps/retrieval/main.py => routers/retrieval.py} (99%) rename backend/open_webui/{apps/webui => }/routers/tools.py (98%) rename backend/open_webui/{apps/webui => }/routers/users.py (98%) rename backend/open_webui/{apps/webui => }/routers/utils.py (95%) rename backend/open_webui/{apps/webui/main.py => routers/webui.py} (98%) rename backend/open_webui/{apps => }/socket/main.py (99%) rename backend/open_webui/{apps => }/socket/utils.py (100%) rename backend/open_webui/{apps/images/utils => utils/images}/comfyui.py (100%) rename backend/open_webui/{apps/webui/utils.py => utils/plugin.py} (98%) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 905d2472a..955b3423e 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -10,7 +10,7 @@ from urllib.parse import urlparse import chromadb import requests import yaml -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from open_webui.env import ( OPEN_WEBUI_DIR, DATA_DIR, @@ -432,7 +432,10 @@ OAUTH_ADMIN_ROLES = PersistentConfig( OAUTH_ALLOWED_DOMAINS = PersistentConfig( "OAUTH_ALLOWED_DOMAINS", "oauth.allowed_domains", - [domain.strip() for domain in os.environ.get("OAUTH_ALLOWED_DOMAINS", "*").split(",")], + [ + domain.strip() + for domain in os.environ.get("OAUTH_ALLOWED_DOMAINS", "*").split(",") + ], ) diff --git a/backend/open_webui/apps/webui/internal/db.py b/backend/open_webui/internal/db.py similarity index 97% rename from backend/open_webui/apps/webui/internal/db.py rename to backend/open_webui/internal/db.py index bcf913e6f..ba078822e 100644 --- a/backend/open_webui/apps/webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -3,7 +3,7 @@ import logging from contextlib import contextmanager from typing import Any, Optional -from open_webui.apps.webui.internal.wrappers import register_connection +from open_webui.internal.wrappers import register_connection from open_webui.env import ( OPEN_WEBUI_DIR, DATABASE_URL, diff --git a/backend/open_webui/apps/webui/internal/migrations/001_initial_schema.py b/backend/open_webui/internal/migrations/001_initial_schema.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/001_initial_schema.py rename to backend/open_webui/internal/migrations/001_initial_schema.py diff --git a/backend/open_webui/apps/webui/internal/migrations/002_add_local_sharing.py b/backend/open_webui/internal/migrations/002_add_local_sharing.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/002_add_local_sharing.py rename to backend/open_webui/internal/migrations/002_add_local_sharing.py diff --git a/backend/open_webui/apps/webui/internal/migrations/003_add_auth_api_key.py b/backend/open_webui/internal/migrations/003_add_auth_api_key.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/003_add_auth_api_key.py rename to backend/open_webui/internal/migrations/003_add_auth_api_key.py diff --git a/backend/open_webui/apps/webui/internal/migrations/004_add_archived.py b/backend/open_webui/internal/migrations/004_add_archived.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/004_add_archived.py rename to backend/open_webui/internal/migrations/004_add_archived.py diff --git a/backend/open_webui/apps/webui/internal/migrations/005_add_updated_at.py b/backend/open_webui/internal/migrations/005_add_updated_at.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/005_add_updated_at.py rename to backend/open_webui/internal/migrations/005_add_updated_at.py diff --git a/backend/open_webui/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py b/backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py rename to backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py diff --git a/backend/open_webui/apps/webui/internal/migrations/007_add_user_last_active_at.py b/backend/open_webui/internal/migrations/007_add_user_last_active_at.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/007_add_user_last_active_at.py rename to backend/open_webui/internal/migrations/007_add_user_last_active_at.py diff --git a/backend/open_webui/apps/webui/internal/migrations/008_add_memory.py b/backend/open_webui/internal/migrations/008_add_memory.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/008_add_memory.py rename to backend/open_webui/internal/migrations/008_add_memory.py diff --git a/backend/open_webui/apps/webui/internal/migrations/009_add_models.py b/backend/open_webui/internal/migrations/009_add_models.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/009_add_models.py rename to backend/open_webui/internal/migrations/009_add_models.py diff --git a/backend/open_webui/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py b/backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py rename to backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py diff --git a/backend/open_webui/apps/webui/internal/migrations/011_add_user_settings.py b/backend/open_webui/internal/migrations/011_add_user_settings.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/011_add_user_settings.py rename to backend/open_webui/internal/migrations/011_add_user_settings.py diff --git a/backend/open_webui/apps/webui/internal/migrations/012_add_tools.py b/backend/open_webui/internal/migrations/012_add_tools.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/012_add_tools.py rename to backend/open_webui/internal/migrations/012_add_tools.py diff --git a/backend/open_webui/apps/webui/internal/migrations/013_add_user_info.py b/backend/open_webui/internal/migrations/013_add_user_info.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/013_add_user_info.py rename to backend/open_webui/internal/migrations/013_add_user_info.py diff --git a/backend/open_webui/apps/webui/internal/migrations/014_add_files.py b/backend/open_webui/internal/migrations/014_add_files.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/014_add_files.py rename to backend/open_webui/internal/migrations/014_add_files.py diff --git a/backend/open_webui/apps/webui/internal/migrations/015_add_functions.py b/backend/open_webui/internal/migrations/015_add_functions.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/015_add_functions.py rename to backend/open_webui/internal/migrations/015_add_functions.py diff --git a/backend/open_webui/apps/webui/internal/migrations/016_add_valves_and_is_active.py b/backend/open_webui/internal/migrations/016_add_valves_and_is_active.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/016_add_valves_and_is_active.py rename to backend/open_webui/internal/migrations/016_add_valves_and_is_active.py diff --git a/backend/open_webui/apps/webui/internal/migrations/017_add_user_oauth_sub.py b/backend/open_webui/internal/migrations/017_add_user_oauth_sub.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/017_add_user_oauth_sub.py rename to backend/open_webui/internal/migrations/017_add_user_oauth_sub.py diff --git a/backend/open_webui/apps/webui/internal/migrations/018_add_function_is_global.py b/backend/open_webui/internal/migrations/018_add_function_is_global.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/018_add_function_is_global.py rename to backend/open_webui/internal/migrations/018_add_function_is_global.py diff --git a/backend/open_webui/apps/webui/internal/wrappers.py b/backend/open_webui/internal/wrappers.py similarity index 100% rename from backend/open_webui/apps/webui/internal/wrappers.py rename to backend/open_webui/internal/wrappers.py diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 5ab820981..ab43ef8b4 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -9,8 +9,11 @@ import sys import time import random from contextlib import asynccontextmanager -from typing import Optional +from urllib.parse import urlencode, parse_qs, urlparse +from pydantic import BaseModel +from sqlalchemy import text +from typing import Optional from aiocache import cached import aiohttp import requests @@ -27,112 +30,201 @@ from fastapi import ( from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel -from sqlalchemy import text + from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware from starlette.responses import Response, StreamingResponse -from open_webui.apps.audio.main import app as audio_app -from open_webui.apps.images.main import app as images_app -from open_webui.apps.ollama.main import ( - app as ollama_app, - get_all_models as get_ollama_models, - generate_chat_completion as generate_ollama_chat_completion, - GenerateChatCompletionForm, + +from open_webui.routers import ( + audio, + chat, + images, + ollama, + openai, + retrieval, + pipelines, + tasks, ) -from open_webui.apps.openai.main import ( - app as openai_app, - generate_chat_completion as generate_openai_chat_completion, - get_all_models as get_openai_models, - get_all_models_responses as get_openai_models_responses, -) -from open_webui.apps.retrieval.main import app as retrieval_app -from open_webui.apps.retrieval.utils import get_sources_from_files + +from open_webui.retrieval.utils import get_sources_from_files -from open_webui.apps.socket.main import ( +from open_webui.socket.main import ( app as socket_app, periodic_usage_pool_cleanup, get_event_call, get_event_emitter, ) -from open_webui.apps.webui.internal.db import Session -from open_webui.apps.webui.main import ( + + +from open_webui.internal.db import Session + + +from backend.open_webui.routers.webui import ( app as webui_app, generate_function_chat_completion, get_all_models as get_open_webui_models, ) -from open_webui.apps.webui.models.functions import Functions -from open_webui.apps.webui.models.models import Models -from open_webui.apps.webui.models.users import UserModel, Users -from open_webui.apps.webui.utils import load_function_module_by_id +from open_webui.models.functions import Functions +from open_webui.models.models import Models +from open_webui.models.users import UserModel, Users +from backend.open_webui.utils.plugin import load_function_module_by_id + + +from open_webui.constants import TASKS from open_webui.config import ( - CACHE_DIR, - CORS_ALLOW_ORIGIN, - DEFAULT_LOCALE, - ENABLE_ADMIN_CHAT_ACCESS, - ENABLE_ADMIN_EXPORT, + # Ollama ENABLE_OLLAMA_API, + OLLAMA_BASE_URLS, + OLLAMA_API_CONFIGS, + # OpenAI ENABLE_OPENAI_API, - ENABLE_TAGS_GENERATION, - ENV, - FRONTEND_BUILD_DIR, - OAUTH_PROVIDERS, - STATIC_DIR, - TASK_MODEL, - TASK_MODEL_EXTERNAL, - ENABLE_SEARCH_QUERY_GENERATION, - ENABLE_RETRIEVAL_QUERY_GENERATION, - QUERY_GENERATION_PROMPT_TEMPLATE, - DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, - TITLE_GENERATION_PROMPT_TEMPLATE, - TAGS_GENERATION_PROMPT_TEMPLATE, - ENABLE_AUTOCOMPLETE_GENERATION, - AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, - AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, - DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - WEBHOOK_URL, + OPENAI_API_BASE_URLS, + OPENAI_API_KEYS, + OPENAI_API_CONFIGS, + # Image + AUTOMATIC1111_API_AUTH, + AUTOMATIC1111_BASE_URL, + AUTOMATIC1111_CFG_SCALE, + AUTOMATIC1111_SAMPLER, + AUTOMATIC1111_SCHEDULER, + COMFYUI_BASE_URL, + COMFYUI_WORKFLOW, + COMFYUI_WORKFLOW_NODES, + ENABLE_IMAGE_GENERATION, + IMAGE_GENERATION_ENGINE, + IMAGE_GENERATION_MODEL, + IMAGE_SIZE, + IMAGE_STEPS, + IMAGES_OPENAI_API_BASE_URL, + IMAGES_OPENAI_API_KEY, + # Audio + AUDIO_STT_ENGINE, + AUDIO_STT_MODEL, + AUDIO_STT_OPENAI_API_BASE_URL, + AUDIO_STT_OPENAI_API_KEY, + AUDIO_TTS_API_KEY, + AUDIO_TTS_ENGINE, + AUDIO_TTS_MODEL, + AUDIO_TTS_OPENAI_API_BASE_URL, + AUDIO_TTS_OPENAI_API_KEY, + AUDIO_TTS_SPLIT_ON, + AUDIO_TTS_VOICE, + AUDIO_TTS_AZURE_SPEECH_REGION, + AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, + WHISPER_MODEL, + WHISPER_MODEL_AUTO_UPDATE, + WHISPER_MODEL_DIR, + # WebUI WEBUI_AUTH, WEBUI_NAME, + WEBUI_BANNERS, + WEBHOOK_URL, + ADMIN_EMAIL, + SHOW_ADMIN_DETAILS, + JWT_EXPIRES_IN, + ENABLE_SIGNUP, + ENABLE_LOGIN_FORM, + ENABLE_API_KEY, + ENABLE_COMMUNITY_SHARING, + ENABLE_MESSAGE_RATING, + ENABLE_EVALUATION_ARENA_MODELS, + USER_PERMISSIONS, + DEFAULT_USER_ROLE, + DEFAULT_PROMPT_SUGGESTIONS, + DEFAULT_MODELS, + DEFAULT_ARENA_MODEL, + MODEL_ORDER_LIST, + EVALUATION_ARENA_MODELS, + # WebUI (OAuth) + ENABLE_OAUTH_ROLE_MANAGEMENT, + OAUTH_ROLES_CLAIM, + OAUTH_EMAIL_CLAIM, + OAUTH_PICTURE_CLAIM, + OAUTH_USERNAME_CLAIM, + OAUTH_ALLOWED_ROLES, + OAUTH_ADMIN_ROLES, + # WebUI (LDAP) + ENABLE_LDAP, + LDAP_SERVER_LABEL, + LDAP_SERVER_HOST, + LDAP_SERVER_PORT, + LDAP_ATTRIBUTE_FOR_USERNAME, + LDAP_SEARCH_FILTERS, + LDAP_SEARCH_BASE, + LDAP_APP_DN, + LDAP_APP_PASSWORD, + LDAP_USE_TLS, + LDAP_CA_CERT_FILE, + LDAP_CIPHERS, + # Misc + ENV, + CACHE_DIR, + STATIC_DIR, + FRONTEND_BUILD_DIR, + CORS_ALLOW_ORIGIN, + DEFAULT_LOCALE, + OAUTH_PROVIDERS, + # Admin + ENABLE_ADMIN_CHAT_ACCESS, + ENABLE_ADMIN_EXPORT, + # Tasks + TASK_MODEL, + TASK_MODEL_EXTERNAL, + ENABLE_TAGS_GENERATION, + ENABLE_SEARCH_QUERY_GENERATION, + ENABLE_RETRIEVAL_QUERY_GENERATION, + ENABLE_AUTOCOMPLETE_GENERATION, + TITLE_GENERATION_PROMPT_TEMPLATE, + TAGS_GENERATION_PROMPT_TEMPLATE, + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + QUERY_GENERATION_PROMPT_TEMPLATE, + AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, AppConfig, reset_config, ) -from open_webui.constants import TASKS from open_webui.env import ( CHANGELOG, GLOBAL_LOG_LEVEL, SAFE_MODE, SRC_LOG_LEVELS, VERSION, + WEBUI_URL, WEBUI_BUILD_HASH, WEBUI_SECRET_KEY, WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE, - WEBUI_URL, + WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + WEBUI_AUTH_TRUSTED_NAME_HEADER, BYPASS_MODEL_ACCESS_CONTROL, RESET_CONFIG_ON_START, OFFLINE_MODE, ) + + from open_webui.utils.misc import ( add_or_update_system_message, get_last_user_message, prepend_to_first_user_message_content, ) -from open_webui.utils.oauth import oauth_manager + + from open_webui.utils.payload import convert_payload_openai_to_ollama from open_webui.utils.response import ( convert_response_ollama_to_openai, convert_streaming_response_ollama_to_openai, ) -from open_webui.utils.security_headers import SecurityHeadersMiddleware + from open_webui.utils.task import ( rag_template, tools_function_calling_generation_template, ) from open_webui.utils.tools import get_tools +from open_webui.utils.access_control import has_access + from open_webui.utils.auth import ( decode_token, get_admin_user, @@ -140,7 +232,9 @@ from open_webui.utils.auth import ( get_http_authorization_cred, get_verified_user, ) -from open_webui.utils.access_control import has_access +from open_webui.utils.oauth import oauth_manager +from open_webui.utils.security_headers import SecurityHeadersMiddleware + if SAFE_MODE: print("SAFE MODE ENABLED") @@ -197,36 +291,186 @@ app = FastAPI( app.state.config = AppConfig() -app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API -app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API +######################################## +# +# OLLAMA +# +######################################## + + +app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API +app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS +app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS + + +######################################## +# +# OPENAI +# +######################################## + +app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API +app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS +app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS +app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS + + +######################################## +# +# WEBUI +# +######################################## + +app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP +app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM +app.state.config.ENABLE_API_KEY = ENABLE_API_KEY + +app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN + +app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS +app.state.config.ADMIN_EMAIL = ADMIN_EMAIL + + +app.state.config.DEFAULT_MODELS = DEFAULT_MODELS +app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS +app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE + +app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.WEBHOOK_URL = WEBHOOK_URL +app.state.config.BANNERS = WEBUI_BANNERS +app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST + +app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING +app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING + +app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS +app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS + +app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM +app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM +app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM + +app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT +app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM +app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES +app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES + +app.state.config.ENABLE_LDAP = ENABLE_LDAP +app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL +app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST +app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT +app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME +app.state.config.LDAP_APP_DN = LDAP_APP_DN +app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD +app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE +app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS +app.state.config.LDAP_USE_TLS = LDAP_USE_TLS +app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE +app.state.config.LDAP_CIPHERS = LDAP_CIPHERS + + +app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER +app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER + +app.state.TOOLS = {} +app.state.FUNCTIONS = {} + + +######################################## +# +# RETRIEVAL +# +######################################## + +######################################## +# +# IMAGES +# +######################################## + +app.state.config.IMAGE_GENERATION_ENGINE = IMAGE_GENERATION_ENGINE +app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION + +app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL +app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY + +app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL + +app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL +app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH +app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE +app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER +app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER +app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL +app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW +app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES + +app.state.config.IMAGE_SIZE = IMAGE_SIZE +app.state.config.IMAGE_STEPS = IMAGE_STEPS + + +######################################## +# +# AUDIO +# +######################################## + +app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL +app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY +app.state.config.STT_ENGINE = AUDIO_STT_ENGINE +app.state.config.STT_MODEL = AUDIO_STT_MODEL + +app.state.config.WHISPER_MODEL = WHISPER_MODEL + +app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL +app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY +app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE +app.state.config.TTS_MODEL = AUDIO_TTS_MODEL +app.state.config.TTS_VOICE = AUDIO_TTS_VOICE +app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY +app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON + + +app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION +app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT + + +app.state.faster_whisper_model = None +app.state.speech_synthesiser = None +app.state.speech_speaker_embeddings_dataset = None + + +######################################## +# +# TASKS +# +######################################## + app.state.config.TASK_MODEL = TASK_MODEL app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL -app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE +app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION +app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION +app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION + + +app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE +app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE +app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE +) +app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE +app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = ( + AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE +) app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH ) -app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION -app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE - - -app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION -app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION -app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE - -app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = ( - AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE -) - -app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE -) ################################## # @@ -570,13 +814,6 @@ async def chat_completion_files_handler( return body, {"sources": sources} -def is_chat_completion_request(request): - return request.method == "POST" and any( - endpoint in request.url.path - for endpoint in ["/ollama/api/chat", "/chat/completions"] - ) - - async def get_body_and_model_and_user(request, models): # Read the original request body body = await request.body() @@ -598,7 +835,10 @@ async def get_body_and_model_and_user(request, models): class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - if not is_chat_completion_request(request): + if not request.method == "POST" and any( + endpoint in request.url.path + for endpoint in ["/ollama/api/chat", "/chat/completions"] + ): return await call_next(request) log.debug(f"request.url.path: {request.url.path}") @@ -875,7 +1115,10 @@ def filter_pipeline(payload, user, models): class PipelineMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - if not is_chat_completion_request(request): + if not request.method == "POST" and any( + endpoint in request.url.path + for endpoint in ["/ollama/api/chat", "/chat/completions"] + ): return await call_next(request) log.debug(f"request.url.path: {request.url.path}") @@ -945,9 +1188,6 @@ class PipelineMiddleware(BaseHTTPMiddleware): app.add_middleware(PipelineMiddleware) -from urllib.parse import urlencode, parse_qs, urlparse - - class RedirectMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): # Check if the request is a GET request @@ -969,16 +1209,6 @@ class RedirectMiddleware(BaseHTTPMiddleware): # Add the middleware to the app app.add_middleware(RedirectMiddleware) - - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - app.add_middleware(SecurityHeadersMiddleware) @@ -1000,12 +1230,12 @@ async def check_url(request: Request, call_next): return response -@app.middleware("http") -async def update_embedding_function(request: Request, call_next): - response = await call_next(request) - if "/embedding/update" in request.url.path: - webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION - return response +# @app.middleware("http") +# async def update_embedding_function(request: Request, call_next): +# response = await call_next(request) +# if "/embedding/update" in request.url.path: +# webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION +# return response @app.middleware("http") @@ -1026,17 +1256,30 @@ async def inspect_websocket(request: Request, call_next): return await call_next(request) +app.add_middleware( + CORSMiddleware, + allow_origins=CORS_ALLOW_ORIGIN, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + app.mount("/ws", socket_app) + + app.mount("/ollama", ollama_app) app.mount("/openai", openai_app) app.mount("/images/api/v1", images_app) app.mount("/audio/api/v1", audio_app) + + app.mount("/retrieval/api/v1", retrieval_app) app.mount("/api/v1", webui_app) -webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION +app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION async def get_all_base_models(): @@ -1045,7 +1288,7 @@ async def get_all_base_models(): ollama_models = [] if app.state.config.ENABLE_OPENAI_API: - openai_models = await get_openai_models() + openai_models = await openai.get_all_models() openai_models = openai_models["data"] if app.state.config.ENABLE_OLLAMA_API: @@ -1255,740 +1498,6 @@ async def get_base_models(user=Depends(get_admin_user)): return {"data": models} -@app.post("/api/chat/completions") -async def generate_chat_completions( - request: Request, - form_data: dict, - user=Depends(get_verified_user), - bypass_filter: bool = False, -): - if BYPASS_MODEL_ACCESS_CONTROL: - bypass_filter = True - - model_list = request.state.models - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - model = models[model_id] - - # Check if user has access to the model - if not bypass_filter and user.role == "user": - if model.get("arena"): - if not has_access( - user.id, - type="read", - access_control=model.get("info", {}) - .get("meta", {}) - .get("access_control", {}), - ): - raise HTTPException( - status_code=403, - detail="Model not found", - ) - else: - model_info = Models.get_model_by_id(model_id) - if not model_info: - raise HTTPException( - status_code=404, - detail="Model not found", - ) - elif not ( - user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control - ) - ): - raise HTTPException( - status_code=403, - detail="Model not found", - ) - - if model["owned_by"] == "arena": - model_ids = model.get("info", {}).get("meta", {}).get("model_ids") - filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") - if model_ids and filter_mode == "exclude": - model_ids = [ - model["id"] - for model in await get_all_models() - if model.get("owned_by") != "arena" and model["id"] not in model_ids - ] - - selected_model_id = None - if isinstance(model_ids, list) and model_ids: - selected_model_id = random.choice(model_ids) - else: - model_ids = [ - model["id"] - for model in await get_all_models() - if model.get("owned_by") != "arena" - ] - selected_model_id = random.choice(model_ids) - - form_data["model"] = selected_model_id - - if form_data.get("stream") == True: - - async def stream_wrapper(stream): - yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" - async for chunk in stream: - yield chunk - - response = await generate_chat_completions( - form_data, user, bypass_filter=True - ) - return StreamingResponse( - stream_wrapper(response.body_iterator), media_type="text/event-stream" - ) - else: - return { - **( - await generate_chat_completions(form_data, user, bypass_filter=True) - ), - "selected_model_id": selected_model_id, - } - - if model.get("pipe"): - # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter - return await generate_function_chat_completion( - form_data, user=user, models=models - ) - if model["owned_by"] == "ollama": - # Using /ollama/api/chat endpoint - form_data = convert_payload_openai_to_ollama(form_data) - form_data = GenerateChatCompletionForm(**form_data) - response = await generate_ollama_chat_completion( - form_data=form_data, user=user, bypass_filter=bypass_filter - ) - if form_data.stream: - response.headers["content-type"] = "text/event-stream" - return StreamingResponse( - convert_streaming_response_ollama_to_openai(response), - headers=dict(response.headers), - ) - else: - return convert_response_ollama_to_openai(response) - else: - return await generate_openai_chat_completion( - form_data, user=user, bypass_filter=bypass_filter - ) - - -@app.post("/api/chat/completed") -async def chat_completed(form_data: dict, user=Depends(get_verified_user)): - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - data = form_data - model_id = data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - model = models[model_id] - sorted_filters = get_sorted_filters(model_id, models) - if "pipeline" in model: - sorted_filters = [model] + sorted_filters - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/outlet", - headers=headers, - json={ - "user": { - "id": user.id, - "name": user.name, - "email": user.email, - "role": user.role, - }, - "body": data, - }, - ) - - r.raise_for_status() - data = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - try: - res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) - except Exception: - pass - - else: - pass - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel to include vavles - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - # Sort filter_ids by priority, using the get_priority function - filter_ids.sort(key=get_priority) - - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if not hasattr(function_module, "outlet"): - continue - try: - outlet = function_module.outlet - - # Get the signature of the function - sig = inspect.signature(outlet) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(outlet): - data = await outlet(**params) - else: - data = outlet(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data - - -@app.post("/api/chat/actions/{action_id}") -async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)): - if "." in action_id: - action_id, sub_action_id = action_id.split(".") - else: - sub_action_id = None - - action = Functions.get_function_by_id(action_id) - if not action: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Action not found", - ) - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - data = form_data - model_id = data["model"] - - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - model = models[model_id] - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - if action_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[action_id] - else: - function_module, _, _ = load_function_module_by_id(action_id) - webui_app.state.FUNCTIONS[action_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(action_id) - function_module.valves = function_module.Valves(**(valves if valves else {})) - - if hasattr(function_module, "action"): - try: - action = function_module.action - - # Get the signature of the function - sig = inspect.signature(action) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": sub_action_id if sub_action_id is not None else action_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - action_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(action): - data = await action(**params) - else: - data = action(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data - - -################################## -# -# Pipelines Endpoints -# -################################## - - -# TODO: Refactor pipelines API endpoints below into a separate file - - -@app.get("/api/pipelines/list") -async def get_pipelines_list(user=Depends(get_admin_user)): - responses = await get_openai_models_responses() - - log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}") - urlIdxs = [ - idx - for idx, response in enumerate(responses) - if response is not None and "pipelines" in response - ] - - return { - "data": [ - { - "url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx], - "idx": urlIdx, - } - for urlIdx in urlIdxs - ] - } - - -@app.post("/api/pipelines/upload") -async def upload_pipeline( - urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) -): - print("upload_pipeline", urlIdx, file.filename) - # Check if the uploaded file is a python file - if not (file.filename and file.filename.endswith(".py")): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Only Python (.py) files are allowed.", - ) - - upload_folder = f"{CACHE_DIR}/pipelines" - os.makedirs(upload_folder, exist_ok=True) - file_path = os.path.join(upload_folder, file.filename) - - r = None - try: - # Save the uploaded file - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - - with open(file_path, "rb") as f: - files = {"file": f} - r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - status_code = status.HTTP_404_NOT_FOUND - if r is not None: - status_code = r.status_code - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=status_code, - detail=detail, - ) - finally: - # Ensure the file is deleted after the upload is completed or on failure - if os.path.exists(file_path): - os.remove(file_path) - - -class AddPipelineForm(BaseModel): - url: str - urlIdx: int - - -@app.post("/api/pipelines/add") -async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)): - r = None - try: - urlIdx = form_data.urlIdx - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/pipelines/add", headers=headers, json={"url": form_data.url} - ) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -class DeletePipelineForm(BaseModel): - id: str - urlIdx: int - - -@app.delete("/api/pipelines/delete") -async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)): - r = None - try: - urlIdx = form_data.urlIdx - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.delete( - f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id} - ) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.get("/api/pipelines") -async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/pipelines", headers=headers) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.get("/api/pipelines/{pipeline_id}/valves") -async def get_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), -): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.get("/api/pipelines/{pipeline_id}/valves/spec") -async def get_pipeline_valves_spec( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), -): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.post("/api/pipelines/{pipeline_id}/valves/update") -async def update_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - form_data: dict, - user=Depends(get_admin_user), -): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{pipeline_id}/valves/update", - headers=headers, - json={**form_data}, - ) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - ################################## # # Config Endpoints @@ -2075,7 +1584,8 @@ async def get_app_config(request: Request): } -# TODO: webhook endpoint should be under config endpoints +class UrlForm(BaseModel): + url: str @app.get("/api/webhook") @@ -2085,10 +1595,6 @@ async def get_webhook_url(user=Depends(get_admin_user)): } -class UrlForm(BaseModel): - url: str - - @app.post("/api/webhook") async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): app.state.config.WEBHOOK_URL = form_data.url @@ -2103,11 +1609,6 @@ async def get_app_version(): } -@app.get("/api/changelog") -async def get_app_changelog(): - return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5} - - @app.get("/api/version/updates") async def get_app_latest_release_version(): if OFFLINE_MODE: @@ -2131,6 +1632,11 @@ async def get_app_latest_release_version(): return {"current": VERSION, "latest": VERSION} +@app.get("/api/changelog") +async def get_app_changelog(): + return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5} + + ############################ # OAuth Login & Callback ############################ @@ -2218,7 +1724,6 @@ async def healthcheck_with_db(): app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") - if os.path.exists(FRONTEND_BUILD_DIR): mimetypes.add_type("text/javascript", ".js") app.mount( diff --git a/backend/open_webui/migrations/env.py b/backend/open_webui/migrations/env.py index 5e860c8a0..128881647 100644 --- a/backend/open_webui/migrations/env.py +++ b/backend/open_webui/migrations/env.py @@ -1,7 +1,7 @@ from logging.config import fileConfig from alembic import context -from open_webui.apps.webui.models.auths import Auth +from open_webui.models.auths import Auth from open_webui.env import DATABASE_URL from sqlalchemy import engine_from_config, pool diff --git a/backend/open_webui/migrations/script.py.mako b/backend/open_webui/migrations/script.py.mako index 01e730e77..bcf5567fd 100644 --- a/backend/open_webui/migrations/script.py.mako +++ b/backend/open_webui/migrations/script.py.mako @@ -9,7 +9,7 @@ from typing import Sequence, Union from alembic import op import sqlalchemy as sa -import open_webui.apps.webui.internal.db +import open_webui.internal.db ${imports if imports else ""} # revision identifiers, used by Alembic. diff --git a/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py b/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py index 607a7b2c9..9e56282ef 100644 --- a/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py +++ b/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py @@ -11,8 +11,8 @@ from typing import Sequence, Union import sqlalchemy as sa from alembic import op -import open_webui.apps.webui.internal.db -from open_webui.apps.webui.internal.db import JSONField +import open_webui.internal.db +from open_webui.internal.db import JSONField from open_webui.migrations.util import get_existing_tables # revision identifiers, used by Alembic. diff --git a/backend/open_webui/apps/webui/models/auths.py b/backend/open_webui/models/auths.py similarity index 97% rename from backend/open_webui/apps/webui/models/auths.py rename to backend/open_webui/models/auths.py index 391b2e9ec..f07c36c73 100644 --- a/backend/open_webui/apps/webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -2,8 +2,8 @@ import logging import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.users import UserModel, Users +from open_webui.internal.db import Base, get_db +from open_webui.models.users import UserModel, Users from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel from sqlalchemy import Boolean, Column, String, Text diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/models/chats.py similarity index 99% rename from backend/open_webui/apps/webui/models/chats.py rename to backend/open_webui/models/chats.py index 21250add8..3e621a150 100644 --- a/backend/open_webui/apps/webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -3,8 +3,8 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.tags import TagModel, Tag, Tags +from open_webui.internal.db import Base, get_db +from open_webui.models.tags import TagModel, Tag, Tags from pydantic import BaseModel, ConfigDict diff --git a/backend/open_webui/apps/webui/models/feedbacks.py b/backend/open_webui/models/feedbacks.py similarity index 98% rename from backend/open_webui/apps/webui/models/feedbacks.py rename to backend/open_webui/models/feedbacks.py index c2356dfd8..7ff5c4540 100644 --- a/backend/open_webui/apps/webui/models/feedbacks.py +++ b/backend/open_webui/models/feedbacks.py @@ -3,8 +3,8 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.chats import Chats +from open_webui.internal.db import Base, get_db +from open_webui.models.chats import Chats from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict diff --git a/backend/open_webui/apps/webui/models/files.py b/backend/open_webui/models/files.py similarity index 98% rename from backend/open_webui/apps/webui/models/files.py rename to backend/open_webui/models/files.py index 31c9164b6..4050b0140 100644 --- a/backend/open_webui/apps/webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -2,7 +2,7 @@ import logging import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db +from open_webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON diff --git a/backend/open_webui/apps/webui/models/folders.py b/backend/open_webui/models/folders.py similarity index 98% rename from backend/open_webui/apps/webui/models/folders.py rename to backend/open_webui/models/folders.py index 90e8880aa..040774196 100644 --- a/backend/open_webui/apps/webui/models/folders.py +++ b/backend/open_webui/models/folders.py @@ -3,8 +3,8 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.chats import Chats +from open_webui.internal.db import Base, get_db +from open_webui.models.chats import Chats from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict diff --git a/backend/open_webui/apps/webui/models/functions.py b/backend/open_webui/models/functions.py similarity index 98% rename from backend/open_webui/apps/webui/models/functions.py rename to backend/open_webui/models/functions.py index fda155075..6c6aed862 100644 --- a/backend/open_webui/apps/webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -2,8 +2,8 @@ import logging import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db -from open_webui.apps.webui.models.users import Users +from open_webui.internal.db import Base, JSONField, get_db +from open_webui.models.users import Users from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Boolean, Column, String, Text diff --git a/backend/open_webui/apps/webui/models/groups.py b/backend/open_webui/models/groups.py similarity index 97% rename from backend/open_webui/apps/webui/models/groups.py rename to backend/open_webui/models/groups.py index e692198cd..8f0728411 100644 --- a/backend/open_webui/apps/webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -4,10 +4,10 @@ import time from typing import Optional import uuid -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from open_webui.env import SRC_LOG_LEVELS -from open_webui.apps.webui.models.files import FileMetadataResponse +from open_webui.models.files import FileMetadataResponse from pydantic import BaseModel, ConfigDict diff --git a/backend/open_webui/apps/webui/models/knowledge.py b/backend/open_webui/models/knowledge.py similarity index 97% rename from backend/open_webui/apps/webui/models/knowledge.py rename to backend/open_webui/models/knowledge.py index e1a13b3fd..bed3d5542 100644 --- a/backend/open_webui/apps/webui/models/knowledge.py +++ b/backend/open_webui/models/knowledge.py @@ -4,11 +4,11 @@ import time from typing import Optional import uuid -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from open_webui.env import SRC_LOG_LEVELS -from open_webui.apps.webui.models.files import FileMetadataResponse -from open_webui.apps.webui.models.users import Users, UserResponse +from open_webui.models.files import FileMetadataResponse +from open_webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict diff --git a/backend/open_webui/apps/webui/models/memories.py b/backend/open_webui/models/memories.py similarity index 98% rename from backend/open_webui/apps/webui/models/memories.py rename to backend/open_webui/models/memories.py index 6686058d3..c8dae9726 100644 --- a/backend/open_webui/apps/webui/models/memories.py +++ b/backend/open_webui/models/memories.py @@ -2,7 +2,7 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text diff --git a/backend/open_webui/apps/webui/models/models.py b/backend/open_webui/models/models.py similarity index 98% rename from backend/open_webui/apps/webui/models/models.py rename to backend/open_webui/models/models.py index 50581bc73..f2f59d7c4 100644 --- a/backend/open_webui/apps/webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -2,10 +2,10 @@ import logging import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db +from open_webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS -from open_webui.apps.webui.models.users import Users, UserResponse +from open_webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict diff --git a/backend/open_webui/apps/webui/models/prompts.py b/backend/open_webui/models/prompts.py similarity index 97% rename from backend/open_webui/apps/webui/models/prompts.py rename to backend/open_webui/models/prompts.py index fe9999195..8ef4cd2be 100644 --- a/backend/open_webui/apps/webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -1,8 +1,8 @@ import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.users import Users, UserResponse +from open_webui.internal.db import Base, get_db +from open_webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON diff --git a/backend/open_webui/apps/webui/models/tags.py b/backend/open_webui/models/tags.py similarity index 98% rename from backend/open_webui/apps/webui/models/tags.py rename to backend/open_webui/models/tags.py index 7424a2660..3e812db95 100644 --- a/backend/open_webui/apps/webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -3,7 +3,7 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/apps/webui/models/tools.py b/backend/open_webui/models/tools.py similarity index 98% rename from backend/open_webui/apps/webui/models/tools.py rename to backend/open_webui/models/tools.py index 8f798c317..a5f13ebb7 100644 --- a/backend/open_webui/apps/webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -2,8 +2,8 @@ import logging import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db -from open_webui.apps.webui.models.users import Users, UserResponse +from open_webui.internal.db import Base, JSONField, get_db +from open_webui.models.users import Users, UserResponse from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON diff --git a/backend/open_webui/apps/webui/models/users.py b/backend/open_webui/models/users.py similarity index 98% rename from backend/open_webui/apps/webui/models/users.py rename to backend/open_webui/models/users.py index 5bbcc3099..5b6c27214 100644 --- a/backend/open_webui/apps/webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -1,8 +1,8 @@ import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db -from open_webui.apps.webui.models.chats import Chats +from open_webui.internal.db import Base, JSONField, get_db +from open_webui.models.chats import Chats from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text diff --git a/backend/open_webui/apps/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py similarity index 100% rename from backend/open_webui/apps/retrieval/loaders/main.py rename to backend/open_webui/retrieval/loaders/main.py diff --git a/backend/open_webui/apps/retrieval/loaders/youtube.py b/backend/open_webui/retrieval/loaders/youtube.py similarity index 100% rename from backend/open_webui/apps/retrieval/loaders/youtube.py rename to backend/open_webui/retrieval/loaders/youtube.py diff --git a/backend/open_webui/apps/retrieval/models/colbert.py b/backend/open_webui/retrieval/models/colbert.py similarity index 100% rename from backend/open_webui/apps/retrieval/models/colbert.py rename to backend/open_webui/retrieval/models/colbert.py diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/retrieval/utils.py similarity index 100% rename from backend/open_webui/apps/retrieval/utils.py rename to backend/open_webui/retrieval/utils.py diff --git a/backend/open_webui/apps/retrieval/vector/connector.py b/backend/open_webui/retrieval/vector/connector.py similarity index 100% rename from backend/open_webui/apps/retrieval/vector/connector.py rename to backend/open_webui/retrieval/vector/connector.py diff --git a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py similarity index 100% rename from backend/open_webui/apps/retrieval/vector/dbs/chroma.py rename to backend/open_webui/retrieval/vector/dbs/chroma.py diff --git a/backend/open_webui/apps/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py similarity index 100% rename from backend/open_webui/apps/retrieval/vector/dbs/milvus.py rename to backend/open_webui/retrieval/vector/dbs/milvus.py diff --git a/backend/open_webui/apps/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py similarity index 100% rename from backend/open_webui/apps/retrieval/vector/dbs/opensearch.py rename to backend/open_webui/retrieval/vector/dbs/opensearch.py diff --git a/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py similarity index 99% rename from backend/open_webui/apps/retrieval/vector/dbs/pgvector.py rename to backend/open_webui/retrieval/vector/dbs/pgvector.py index d537943a1..b8317957e 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -40,7 +40,7 @@ class PgvectorClient: # if no pgvector uri, use the existing database connection if not PGVECTOR_DB_URL: - from open_webui.apps.webui.internal.db import Session + from open_webui.internal.db import Session self.session = Session else: diff --git a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py similarity index 100% rename from backend/open_webui/apps/retrieval/vector/dbs/qdrant.py rename to backend/open_webui/retrieval/vector/dbs/qdrant.py diff --git a/backend/open_webui/apps/retrieval/vector/main.py b/backend/open_webui/retrieval/vector/main.py similarity index 100% rename from backend/open_webui/apps/retrieval/vector/main.py rename to backend/open_webui/retrieval/vector/main.py diff --git a/backend/open_webui/apps/retrieval/web/bing.py b/backend/open_webui/retrieval/web/bing.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/bing.py rename to backend/open_webui/retrieval/web/bing.py diff --git a/backend/open_webui/apps/retrieval/web/brave.py b/backend/open_webui/retrieval/web/brave.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/brave.py rename to backend/open_webui/retrieval/web/brave.py diff --git a/backend/open_webui/apps/retrieval/web/duckduckgo.py b/backend/open_webui/retrieval/web/duckduckgo.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/duckduckgo.py rename to backend/open_webui/retrieval/web/duckduckgo.py diff --git a/backend/open_webui/apps/retrieval/web/google_pse.py b/backend/open_webui/retrieval/web/google_pse.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/google_pse.py rename to backend/open_webui/retrieval/web/google_pse.py diff --git a/backend/open_webui/apps/retrieval/web/jina_search.py b/backend/open_webui/retrieval/web/jina_search.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/jina_search.py rename to backend/open_webui/retrieval/web/jina_search.py diff --git a/backend/open_webui/apps/retrieval/web/kagi.py b/backend/open_webui/retrieval/web/kagi.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/kagi.py rename to backend/open_webui/retrieval/web/kagi.py diff --git a/backend/open_webui/apps/retrieval/web/main.py b/backend/open_webui/retrieval/web/main.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/main.py rename to backend/open_webui/retrieval/web/main.py diff --git a/backend/open_webui/apps/retrieval/web/mojeek.py b/backend/open_webui/retrieval/web/mojeek.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/mojeek.py rename to backend/open_webui/retrieval/web/mojeek.py diff --git a/backend/open_webui/apps/retrieval/web/searchapi.py b/backend/open_webui/retrieval/web/searchapi.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/searchapi.py rename to backend/open_webui/retrieval/web/searchapi.py diff --git a/backend/open_webui/apps/retrieval/web/searxng.py b/backend/open_webui/retrieval/web/searxng.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/searxng.py rename to backend/open_webui/retrieval/web/searxng.py diff --git a/backend/open_webui/apps/retrieval/web/serper.py b/backend/open_webui/retrieval/web/serper.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/serper.py rename to backend/open_webui/retrieval/web/serper.py diff --git a/backend/open_webui/apps/retrieval/web/serply.py b/backend/open_webui/retrieval/web/serply.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/serply.py rename to backend/open_webui/retrieval/web/serply.py diff --git a/backend/open_webui/apps/retrieval/web/serpstack.py b/backend/open_webui/retrieval/web/serpstack.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/serpstack.py rename to backend/open_webui/retrieval/web/serpstack.py diff --git a/backend/open_webui/apps/retrieval/web/tavily.py b/backend/open_webui/retrieval/web/tavily.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/tavily.py rename to backend/open_webui/retrieval/web/tavily.py diff --git a/backend/open_webui/apps/retrieval/web/testdata/bing.json b/backend/open_webui/retrieval/web/testdata/bing.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/bing.json rename to backend/open_webui/retrieval/web/testdata/bing.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/brave.json b/backend/open_webui/retrieval/web/testdata/brave.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/brave.json rename to backend/open_webui/retrieval/web/testdata/brave.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/google_pse.json b/backend/open_webui/retrieval/web/testdata/google_pse.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/google_pse.json rename to backend/open_webui/retrieval/web/testdata/google_pse.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/searchapi.json b/backend/open_webui/retrieval/web/testdata/searchapi.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/searchapi.json rename to backend/open_webui/retrieval/web/testdata/searchapi.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/searxng.json b/backend/open_webui/retrieval/web/testdata/searxng.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/searxng.json rename to backend/open_webui/retrieval/web/testdata/searxng.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/serper.json b/backend/open_webui/retrieval/web/testdata/serper.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/serper.json rename to backend/open_webui/retrieval/web/testdata/serper.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/serply.json b/backend/open_webui/retrieval/web/testdata/serply.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/serply.json rename to backend/open_webui/retrieval/web/testdata/serply.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/serpstack.json b/backend/open_webui/retrieval/web/testdata/serpstack.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/serpstack.json rename to backend/open_webui/retrieval/web/testdata/serpstack.json diff --git a/backend/open_webui/apps/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/utils.py rename to backend/open_webui/retrieval/web/utils.py diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/routers/audio.py similarity index 94% rename from backend/open_webui/apps/audio/main.py rename to backend/open_webui/routers/audio.py index a3972f19f..3203727a7 100644 --- a/backend/open_webui/apps/audio/main.py +++ b/backend/open_webui/routers/audio.py @@ -25,11 +25,10 @@ from open_webui.config import ( AUDIO_TTS_VOICE, AUDIO_TTS_AZURE_SPEECH_REGION, AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, - CACHE_DIR, - CORS_ALLOW_ORIGIN, WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE, WHISPER_MODEL_DIR, + CACHE_DIR, AppConfig, ) @@ -55,44 +54,6 @@ MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["AUDIO"]) -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.config = AppConfig() - -app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL -app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY -app.state.config.STT_ENGINE = AUDIO_STT_ENGINE -app.state.config.STT_MODEL = AUDIO_STT_MODEL - -app.state.config.WHISPER_MODEL = WHISPER_MODEL -app.state.faster_whisper_model = None - -app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL -app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY -app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE -app.state.config.TTS_MODEL = AUDIO_TTS_MODEL -app.state.config.TTS_VOICE = AUDIO_TTS_VOICE -app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY -app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON - - -app.state.speech_synthesiser = None -app.state.speech_speaker_embeddings_dataset = None - -app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION -app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT # setting device type for whisper model whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/routers/auths.py similarity index 99% rename from backend/open_webui/apps/webui/routers/auths.py rename to backend/open_webui/routers/auths.py index 094ce568f..0b1f42edf 100644 --- a/backend/open_webui/apps/webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -5,7 +5,7 @@ import datetime import logging from aiohttp import ClientSession -from open_webui.apps.webui.models.auths import ( +from open_webui.models.auths import ( AddUserForm, ApiKey, Auths, @@ -18,7 +18,7 @@ from open_webui.apps.webui.models.auths import ( UpdateProfileForm, UserResponse, ) -from open_webui.apps.webui.models.users import Users +from open_webui.models.users import Users from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.env import ( diff --git a/backend/open_webui/routers/chat.py b/backend/open_webui/routers/chat.py index e69de29bb..fba1ffa1b 100644 --- a/backend/open_webui/routers/chat.py +++ b/backend/open_webui/routers/chat.py @@ -0,0 +1,411 @@ +from fastapi import APIRouter, Depends, HTTPException, Response, status +from pydantic import BaseModel + +router = APIRouter() + + +@app.post("/api/chat/completions") +async def generate_chat_completions( + request: Request, + form_data: dict, + user=Depends(get_verified_user), + bypass_filter: bool = False, +): + if BYPASS_MODEL_ACCESS_CONTROL: + bypass_filter = True + + model_list = request.state.models + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = models[model_id] + + # Check if user has access to the model + if not bypass_filter and user.role == "user": + if model.get("arena"): + if not has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + else: + model_info = Models.get_model_by_id(model_id) + if not model_info: + raise HTTPException( + status_code=404, + detail="Model not found", + ) + elif not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + + if model["owned_by"] == "arena": + model_ids = model.get("info", {}).get("meta", {}).get("model_ids") + filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") + if model_ids and filter_mode == "exclude": + model_ids = [ + model["id"] + for model in await get_all_models() + if model.get("owned_by") != "arena" and model["id"] not in model_ids + ] + + selected_model_id = None + if isinstance(model_ids, list) and model_ids: + selected_model_id = random.choice(model_ids) + else: + model_ids = [ + model["id"] + for model in await get_all_models() + if model.get("owned_by") != "arena" + ] + selected_model_id = random.choice(model_ids) + + form_data["model"] = selected_model_id + + if form_data.get("stream") == True: + + async def stream_wrapper(stream): + yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" + async for chunk in stream: + yield chunk + + response = await generate_chat_completions( + form_data, user, bypass_filter=True + ) + return StreamingResponse( + stream_wrapper(response.body_iterator), media_type="text/event-stream" + ) + else: + return { + **( + await generate_chat_completions(form_data, user, bypass_filter=True) + ), + "selected_model_id": selected_model_id, + } + + if model.get("pipe"): + # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter + return await generate_function_chat_completion( + form_data, user=user, models=models + ) + if model["owned_by"] == "ollama": + # Using /ollama/api/chat endpoint + form_data = convert_payload_openai_to_ollama(form_data) + form_data = GenerateChatCompletionForm(**form_data) + response = await generate_ollama_chat_completion( + form_data=form_data, user=user, bypass_filter=bypass_filter + ) + if form_data.stream: + response.headers["content-type"] = "text/event-stream" + return StreamingResponse( + convert_streaming_response_ollama_to_openai(response), + headers=dict(response.headers), + ) + else: + return convert_response_ollama_to_openai(response) + else: + return await generate_openai_chat_completion( + form_data, user=user, bypass_filter=bypass_filter + ) + + +@app.post("/api/chat/completed") +async def chat_completed(form_data: dict, user=Depends(get_verified_user)): + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + data = form_data + model_id = data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = models[model_id] + sorted_filters = get_sorted_filters(model_id, models) + if "pipeline" in model: + sorted_filters = [model] + sorted_filters + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + if key != "": + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/outlet", + headers=headers, + json={ + "user": { + "id": user.id, + "name": user.name, + "email": user.email, + "role": user.role, + }, + "body": data, + }, + ) + + r.raise_for_status() + data = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + try: + res = r.json() + if "detail" in res: + return JSONResponse( + status_code=r.status_code, + content=res, + ) + except Exception: + pass + + else: + pass + + __event_emitter__ = get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + __event_call__ = get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel to include vavles + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [function.id for function in Functions.get_global_filter_functions()] + if "info" in model and "meta" in model["info"]: + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) + + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + filter_ids = [ + filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids + ] + + # Sort filter_ids by priority, using the get_priority function + filter_ids.sort(key=get_priority) + + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if not filter: + continue + + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + webui_app.state.FUNCTIONS[filter_id] = function_module + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) + + if not hasattr(function_module, "outlet"): + continue + try: + outlet = function_module.outlet + + # Get the signature of the function + sig = inspect.signature(outlet) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": filter_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id + ) + ) + except Exception as e: + print(e) + + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(outlet): + data = await outlet(**params) + else: + data = outlet(**params) + + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + return data + + +@app.post("/api/chat/actions/{action_id}") +async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)): + if "." in action_id: + action_id, sub_action_id = action_id.split(".") + else: + sub_action_id = None + + action = Functions.get_function_by_id(action_id) + if not action: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Action not found", + ) + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + data = form_data + model_id = data["model"] + + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + model = models[model_id] + + __event_emitter__ = get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + __event_call__ = get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + if action_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[action_id] + else: + function_module, _, _ = load_function_module_by_id(action_id) + webui_app.state.FUNCTIONS[action_id] = function_module + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(action_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + + if hasattr(function_module, "action"): + try: + action = function_module.action + + # Get the signature of the function + sig = inspect.signature(action) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": sub_action_id if sub_action_id is not None else action_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + action_id, user.id + ) + ) + except Exception as e: + print(e) + + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(action): + data = await action(**params) + else: + data = action(**params) + + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + return data diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/routers/chats.py similarity index 99% rename from backend/open_webui/apps/webui/routers/chats.py rename to backend/open_webui/routers/chats.py index ec5dae4bf..5e0e75e24 100644 --- a/backend/open_webui/apps/webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -2,15 +2,15 @@ import json import logging from typing import Optional -from open_webui.apps.webui.models.chats import ( +from open_webui.models.chats import ( ChatForm, ChatImportForm, ChatResponse, Chats, ChatTitleIdResponse, ) -from open_webui.apps.webui.models.tags import TagModel, Tags -from open_webui.apps.webui.models.folders import Folders +from open_webui.models.tags import TagModel, Tags +from open_webui.models.folders import Folders from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES diff --git a/backend/open_webui/apps/webui/routers/configs.py b/backend/open_webui/routers/configs.py similarity index 100% rename from backend/open_webui/apps/webui/routers/configs.py rename to backend/open_webui/routers/configs.py diff --git a/backend/open_webui/apps/webui/routers/evaluations.py b/backend/open_webui/routers/evaluations.py similarity index 97% rename from backend/open_webui/apps/webui/routers/evaluations.py rename to backend/open_webui/routers/evaluations.py index 0bcee2a79..f0c4a6b06 100644 --- a/backend/open_webui/apps/webui/routers/evaluations.py +++ b/backend/open_webui/routers/evaluations.py @@ -2,8 +2,8 @@ from typing import Optional from fastapi import APIRouter, Depends, HTTPException, status, Request from pydantic import BaseModel -from open_webui.apps.webui.models.users import Users, UserModel -from open_webui.apps.webui.models.feedbacks import ( +from open_webui.models.users import Users, UserModel +from open_webui.models.feedbacks import ( FeedbackModel, FeedbackResponse, FeedbackForm, diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/routers/files.py similarity index 98% rename from backend/open_webui/apps/webui/routers/files.py rename to backend/open_webui/routers/files.py index 4b7cf1ed4..49deb998f 100644 --- a/backend/open_webui/apps/webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -8,13 +8,13 @@ import mimetypes from open_webui.storage.provider import Storage -from open_webui.apps.webui.models.files import ( +from open_webui.models.files import ( FileForm, FileModel, FileModelResponse, Files, ) -from open_webui.apps.retrieval.main import process_file, ProcessFileForm +from backend.open_webui.routers.retrieval import process_file, ProcessFileForm from open_webui.config import UPLOAD_DIR from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/apps/webui/routers/folders.py b/backend/open_webui/routers/folders.py similarity index 98% rename from backend/open_webui/apps/webui/routers/folders.py rename to backend/open_webui/routers/folders.py index f05781476..ca2fbd213 100644 --- a/backend/open_webui/apps/webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -8,12 +8,12 @@ from pydantic import BaseModel import mimetypes -from open_webui.apps.webui.models.folders import ( +from open_webui.models.folders import ( FolderForm, FolderModel, Folders, ) -from open_webui.apps.webui.models.chats import Chats +from open_webui.models.chats import Chats from open_webui.config import UPLOAD_DIR from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/apps/webui/routers/functions.py b/backend/open_webui/routers/functions.py similarity index 98% rename from backend/open_webui/apps/webui/routers/functions.py rename to backend/open_webui/routers/functions.py index bdd422b95..bb780f112 100644 --- a/backend/open_webui/apps/webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -2,13 +2,13 @@ import os from pathlib import Path from typing import Optional -from open_webui.apps.webui.models.functions import ( +from open_webui.models.functions import ( FunctionForm, FunctionModel, FunctionResponse, Functions, ) -from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports +from backend.open_webui.utils.plugin import load_function_module_by_id, replace_imports from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status diff --git a/backend/open_webui/apps/webui/routers/groups.py b/backend/open_webui/routers/groups.py similarity index 98% rename from backend/open_webui/apps/webui/routers/groups.py rename to backend/open_webui/routers/groups.py index ef392fb6a..e8f8994a4 100644 --- a/backend/open_webui/apps/webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -2,7 +2,7 @@ import os from pathlib import Path from typing import Optional -from open_webui.apps.webui.models.groups import ( +from open_webui.models.groups import ( Groups, GroupForm, GroupUpdateForm, diff --git a/backend/open_webui/apps/images/main.py b/backend/open_webui/routers/images.py similarity index 92% rename from backend/open_webui/apps/images/main.py rename to backend/open_webui/routers/images.py index 14209df2f..f4c12ab64 100644 --- a/backend/open_webui/apps/images/main.py +++ b/backend/open_webui/routers/images.py @@ -9,31 +9,14 @@ from pathlib import Path from typing import Optional import requests -from open_webui.apps.images.utils.comfyui import ( +from open_webui.utils.images.comfyui import ( ComfyUIGenerateImageForm, ComfyUIWorkflow, comfyui_generate_image, ) -from open_webui.config import ( - AUTOMATIC1111_API_AUTH, - AUTOMATIC1111_BASE_URL, - AUTOMATIC1111_CFG_SCALE, - AUTOMATIC1111_SAMPLER, - AUTOMATIC1111_SCHEDULER, - CACHE_DIR, - COMFYUI_BASE_URL, - COMFYUI_WORKFLOW, - COMFYUI_WORKFLOW_NODES, - CORS_ALLOW_ORIGIN, - ENABLE_IMAGE_GENERATION, - IMAGE_GENERATION_ENGINE, - IMAGE_GENERATION_MODEL, - IMAGE_SIZE, - IMAGE_STEPS, - IMAGES_OPENAI_API_BASE_URL, - IMAGES_OPENAI_API_KEY, - AppConfig, -) + + +from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS @@ -54,36 +37,6 @@ app = FastAPI( redoc_url=None, ) -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.config = AppConfig() - -app.state.config.ENGINE = IMAGE_GENERATION_ENGINE -app.state.config.ENABLED = ENABLE_IMAGE_GENERATION - -app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL -app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY - -app.state.config.MODEL = IMAGE_GENERATION_MODEL - -app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL -app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH -app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE -app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER -app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER -app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL -app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW -app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES - -app.state.config.IMAGE_SIZE = IMAGE_SIZE -app.state.config.IMAGE_STEPS = IMAGE_STEPS - @app.get("/config") async def get_config(request: Request, user=Depends(get_admin_user)): diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py similarity index 98% rename from backend/open_webui/apps/webui/routers/knowledge.py rename to backend/open_webui/routers/knowledge.py index d572e83b7..1617f452e 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -4,15 +4,15 @@ from pydantic import BaseModel from fastapi import APIRouter, Depends, HTTPException, status, Request import logging -from open_webui.apps.webui.models.knowledge import ( +from open_webui.models.knowledge import ( Knowledges, KnowledgeForm, KnowledgeResponse, KnowledgeUserResponse, ) -from open_webui.apps.webui.models.files import Files, FileModel +from open_webui.models.files import Files, FileModel from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT -from open_webui.apps.retrieval.main import process_file, ProcessFileForm +from backend.open_webui.routers.retrieval import process_file, ProcessFileForm from open_webui.constants import ERROR_MESSAGES diff --git a/backend/open_webui/apps/webui/routers/memories.py b/backend/open_webui/routers/memories.py similarity index 98% rename from backend/open_webui/apps/webui/routers/memories.py rename to backend/open_webui/routers/memories.py index 60993607f..7973038c4 100644 --- a/backend/open_webui/apps/webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -3,7 +3,7 @@ from pydantic import BaseModel import logging from typing import Optional -from open_webui.apps.webui.models.memories import Memories, MemoryModel +from open_webui.models.memories import Memories, MemoryModel from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.auth import get_verified_user from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/apps/webui/routers/models.py b/backend/open_webui/routers/models.py similarity index 99% rename from backend/open_webui/apps/webui/routers/models.py rename to backend/open_webui/routers/models.py index 2e073219a..db981a913 100644 --- a/backend/open_webui/apps/webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -1,6 +1,6 @@ from typing import Optional -from open_webui.apps.webui.models.models import ( +from open_webui.models.models import ( ModelForm, ModelModel, ModelResponse, diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/routers/ollama.py similarity index 98% rename from backend/open_webui/apps/ollama/main.py rename to backend/open_webui/routers/ollama.py index 48142fd9f..581a881b7 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/routers/ollama.py @@ -12,15 +12,22 @@ import aiohttp from aiocache import cached import requests -from open_webui.apps.webui.models.models import Models + +from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, ConfigDict +from starlette.background import BackgroundTask + + +from open_webui.models.models import Models + + from open_webui.config import ( - CORS_ALLOW_ORIGIN, - ENABLE_OLLAMA_API, - OLLAMA_BASE_URLS, - OLLAMA_API_CONFIGS, UPLOAD_DIR, - AppConfig, ) + + from open_webui.env import ( AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, @@ -30,11 +37,6 @@ from open_webui.env import ( from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENV, SRC_LOG_LEVELS -from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse -from pydantic import BaseModel, ConfigDict -from starlette.background import BackgroundTask from open_webui.utils.misc import ( @@ -52,27 +54,6 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.config = AppConfig() - -app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API -app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS -app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS - - # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, # least connections, or least response time for better resource utilization and performance optimization. diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/routers/openai.py similarity index 97% rename from backend/open_webui/apps/openai/main.py rename to backend/open_webui/routers/openai.py index b64e7b28d..1e9ca4af7 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/routers/openai.py @@ -10,7 +10,7 @@ from aiocache import cached import requests -from open_webui.apps.webui.models.models import Models +from open_webui.models.models import Models from open_webui.config import ( CACHE_DIR, CORS_ALLOW_ORIGIN, @@ -48,29 +48,6 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OPENAI"]) -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) - - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.config = AppConfig() - -app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API -app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS -app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS -app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS - - @app.get("/config") async def get_config(user=Depends(get_admin_user)): return { @@ -91,7 +68,6 @@ class OpenAIConfigForm(BaseModel): @app.post("/config/update") async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)): app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API - app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py index 0d9a32c83..9450d520b 100644 --- a/backend/open_webui/routers/pipelines.py +++ b/backend/open_webui/routers/pipelines.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from starlette.responses import FileResponse -from open_webui.apps.webui.models.chats import ChatTitleMessagesForm +from open_webui.models.chats import ChatTitleMessagesForm from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES @@ -14,86 +14,328 @@ from open_webui.utils.auth import get_admin_user router = APIRouter() -@router.get("/gravatar") -async def get_gravatar( - email: str, +################################## +# +# Pipelines Endpoints +# +################################## + + +# TODO: Refactor pipelines API endpoints below into a separate file + + +@app.get("/api/pipelines/list") +async def get_pipelines_list(user=Depends(get_admin_user)): + responses = await get_openai_models_responses() + + log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}") + urlIdxs = [ + idx + for idx, response in enumerate(responses) + if response is not None and "pipelines" in response + ] + + return { + "data": [ + { + "url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx], + "idx": urlIdx, + } + for urlIdx in urlIdxs + ] + } + + +@app.post("/api/pipelines/upload") +async def upload_pipeline( + urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) ): - return get_gravatar_url(email) - - -class CodeFormatRequest(BaseModel): - code: str - - -@router.post("/code/format") -async def format_code(request: CodeFormatRequest): - try: - formatted_code = black.format_str(request.code, mode=black.Mode()) - return {"code": formatted_code} - except black.NothingChanged: - return {"code": request.code} - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) - - -class MarkdownForm(BaseModel): - md: str - - -@router.post("/markdown") -async def get_html_from_markdown( - form_data: MarkdownForm, -): - return {"html": markdown.markdown(form_data.md)} - - -class ChatForm(BaseModel): - title: str - messages: list[dict] - - -@router.post("/pdf") -async def download_chat_as_pdf( - form_data: ChatTitleMessagesForm, -): - try: - pdf_bytes = PDFGenerator(form_data).generate_chat_pdf() - - return Response( - content=pdf_bytes, - media_type="application/pdf", - headers={"Content-Disposition": "attachment;filename=chat.pdf"}, - ) - except Exception as e: - print(e) - raise HTTPException(status_code=400, detail=str(e)) - - -@router.get("/db/download") -async def download_db(user=Depends(get_admin_user)): - if not ENABLE_ADMIN_EXPORT: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - from open_webui.apps.webui.internal.db import engine - - if engine.name != "sqlite": + print("upload_pipeline", urlIdx, file.filename) + # Check if the uploaded file is a python file + if not (file.filename and file.filename.endswith(".py")): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DB_NOT_SQLITE, + detail="Only Python (.py) files are allowed.", ) - return FileResponse( - engine.url.database, - media_type="application/octet-stream", - filename="webui.db", - ) + + upload_folder = f"{CACHE_DIR}/pipelines" + os.makedirs(upload_folder, exist_ok=True) + file_path = os.path.join(upload_folder, file.filename) + + r = None + try: + # Save the uploaded file + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + + with open(file_path, "rb") as f: + files = {"file": f} + r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + status_code = status.HTTP_404_NOT_FOUND + if r is not None: + status_code = r.status_code + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=status_code, + detail=detail, + ) + finally: + # Ensure the file is deleted after the upload is completed or on failure + if os.path.exists(file_path): + os.remove(file_path) -@router.get("/litellm/config") -async def download_litellm_config_yaml(user=Depends(get_admin_user)): - return FileResponse( - f"{DATA_DIR}/litellm/config.yaml", - media_type="application/octet-stream", - filename="config.yaml", - ) +class AddPipelineForm(BaseModel): + url: str + urlIdx: int + + +@app.post("/api/pipelines/add") +async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)): + r = None + try: + urlIdx = form_data.urlIdx + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/pipelines/add", headers=headers, json={"url": form_data.url} + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) + + +class DeletePipelineForm(BaseModel): + id: str + urlIdx: int + + +@app.delete("/api/pipelines/delete") +async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)): + r = None + try: + urlIdx = form_data.urlIdx + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + r = requests.delete( + f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id} + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) + + +@app.get("/api/pipelines") +async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)): + r = None + try: + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + r = requests.get(f"{url}/pipelines", headers=headers) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) + + +@app.get("/api/pipelines/{pipeline_id}/valves") +async def get_pipeline_valves( + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), +): + r = None + try: + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) + + +@app.get("/api/pipelines/{pipeline_id}/valves/spec") +async def get_pipeline_valves_spec( + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), +): + r = None + try: + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) + + +@app.post("/api/pipelines/{pipeline_id}/valves/update") +async def update_pipeline_valves( + urlIdx: Optional[int], + pipeline_id: str, + form_data: dict, + user=Depends(get_admin_user), +): + r = None + try: + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{pipeline_id}/valves/update", + headers=headers, + json={**form_data}, + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) diff --git a/backend/open_webui/apps/webui/routers/prompts.py b/backend/open_webui/routers/prompts.py similarity index 98% rename from backend/open_webui/apps/webui/routers/prompts.py rename to backend/open_webui/routers/prompts.py index 89a60fd95..4f1c48482 100644 --- a/backend/open_webui/apps/webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -1,6 +1,6 @@ from typing import Optional -from open_webui.apps.webui.models.prompts import ( +from open_webui.models.prompts import ( PromptForm, PromptUserResponse, PromptModel, diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/routers/retrieval.py similarity index 99% rename from backend/open_webui/apps/retrieval/main.py rename to backend/open_webui/routers/retrieval.py index cfbc5beee..517e2894d 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/routers/retrieval.py @@ -18,7 +18,7 @@ import tiktoken from open_webui.storage.provider import Storage -from open_webui.apps.webui.models.knowledge import Knowledges +from open_webui.models.knowledge import Knowledges from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT # Document loaders @@ -43,7 +43,7 @@ from open_webui.apps.retrieval.web.tavily import search_tavily from open_webui.apps.retrieval.web.bing import search_bing -from open_webui.apps.retrieval.utils import ( +from backend.open_webui.retrieval.utils import ( get_embedding_function, get_model_path, query_collection, @@ -52,7 +52,7 @@ from open_webui.apps.retrieval.utils import ( query_doc_with_hybrid_search, ) -from open_webui.apps.webui.models.files import Files +from open_webui.models.files import Files from open_webui.config import ( BRAVE_SEARCH_API_KEY, KAGI_SEARCH_API_KEY, diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index c4ccc4700..4af25d4d3 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -2,6 +2,8 @@ from fastapi import APIRouter, Depends, HTTPException, Response, status, Request from pydantic import BaseModel from starlette.responses import FileResponse from typing import Optional +import logging + from open_webui.utils.task import ( title_generation_template, @@ -12,6 +14,17 @@ from open_webui.utils.task import ( moa_response_generation_template, ) from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.constants import TASKS + +from open_webui.config import ( + DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, + DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, +) +from open_webui.env import SRC_LOG_LEVELS + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) router = APIRouter() @@ -197,7 +210,9 @@ Artificial Intelligence in Healthcare @router.post("/tags/completions") -async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): +async def generate_chat_tags( + request: Request, form_data: dict, user=Depends(get_verified_user) +): if not request.app.state.config.ENABLE_TAGS_GENERATION: return JSONResponse( diff --git a/backend/open_webui/apps/webui/routers/tools.py b/backend/open_webui/routers/tools.py similarity index 98% rename from backend/open_webui/apps/webui/routers/tools.py rename to backend/open_webui/routers/tools.py index 410f12d64..bf4a93f6c 100644 --- a/backend/open_webui/apps/webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -1,14 +1,14 @@ from pathlib import Path from typing import Optional -from open_webui.apps.webui.models.tools import ( +from open_webui.models.tools import ( ToolForm, ToolModel, ToolResponse, ToolUserResponse, Tools, ) -from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports +from backend.open_webui.utils.plugin import load_tools_module_by_id, replace_imports from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status diff --git a/backend/open_webui/apps/webui/routers/users.py b/backend/open_webui/routers/users.py similarity index 98% rename from backend/open_webui/apps/webui/routers/users.py rename to backend/open_webui/routers/users.py index 92131b9ad..1206d56f2 100644 --- a/backend/open_webui/apps/webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -1,9 +1,9 @@ import logging from typing import Optional -from open_webui.apps.webui.models.auths import Auths -from open_webui.apps.webui.models.chats import Chats -from open_webui.apps.webui.models.users import ( +from open_webui.models.auths import Auths +from open_webui.models.chats import Chats +from open_webui.models.users import ( UserModel, UserRoleUpdateForm, Users, diff --git a/backend/open_webui/apps/webui/routers/utils.py b/backend/open_webui/routers/utils.py similarity index 95% rename from backend/open_webui/apps/webui/routers/utils.py rename to backend/open_webui/routers/utils.py index a4c33a03b..ea73e9759 100644 --- a/backend/open_webui/apps/webui/routers/utils.py +++ b/backend/open_webui/routers/utils.py @@ -1,7 +1,7 @@ import black import markdown -from open_webui.apps.webui.models.chats import ChatTitleMessagesForm +from open_webui.models.chats import ChatTitleMessagesForm from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Response, status @@ -76,7 +76,7 @@ async def download_db(user=Depends(get_admin_user)): status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - from open_webui.apps.webui.internal.db import engine + from open_webui.internal.db import engine if engine.name != "sqlite": raise HTTPException( diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/routers/webui.py similarity index 98% rename from backend/open_webui/apps/webui/main.py rename to backend/open_webui/routers/webui.py index 054c6280e..1ac4db152 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/routers/webui.py @@ -5,9 +5,9 @@ import time from typing import AsyncGenerator, Generator, Iterator from open_webui.apps.socket.main import get_event_call, get_event_emitter -from open_webui.apps.webui.models.functions import Functions -from open_webui.apps.webui.models.models import Models -from open_webui.apps.webui.routers import ( +from open_webui.models.functions import Functions +from open_webui.models.models import Models +from open_webui.routers import ( auths, chats, folders, @@ -24,7 +24,7 @@ from open_webui.apps.webui.routers import ( users, utils, ) -from open_webui.apps.webui.utils import load_function_module_by_id +from backend.open_webui.utils.plugin import load_function_module_by_id from open_webui.config import ( ADMIN_EMAIL, CORS_ALLOW_ORIGIN, diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/socket/main.py similarity index 99% rename from backend/open_webui/apps/socket/main.py rename to backend/open_webui/socket/main.py index 8ec8937a1..7a673f098 100644 --- a/backend/open_webui/apps/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -6,7 +6,7 @@ import logging import sys import time -from open_webui.apps.webui.models.users import Users +from open_webui.models.users import Users from open_webui.env import ( ENABLE_WEBSOCKET_SUPPORT, WEBSOCKET_MANAGER, diff --git a/backend/open_webui/apps/socket/utils.py b/backend/open_webui/socket/utils.py similarity index 100% rename from backend/open_webui/apps/socket/utils.py rename to backend/open_webui/socket/utils.py diff --git a/backend/open_webui/test/apps/webui/routers/test_auths.py b/backend/open_webui/test/apps/webui/routers/test_auths.py index cee68228e..f0f69e26d 100644 --- a/backend/open_webui/test/apps/webui/routers/test_auths.py +++ b/backend/open_webui/test/apps/webui/routers/test_auths.py @@ -7,8 +7,8 @@ class TestAuths(AbstractPostgresTest): def setup_class(cls): super().setup_class() - from open_webui.apps.webui.models.auths import Auths - from open_webui.apps.webui.models.users import Users + from open_webui.models.auths import Auths + from open_webui.models.users import Users cls.users = Users cls.auths = Auths diff --git a/backend/open_webui/test/apps/webui/routers/test_chats.py b/backend/open_webui/test/apps/webui/routers/test_chats.py index 935316fd8..a36a01fb1 100644 --- a/backend/open_webui/test/apps/webui/routers/test_chats.py +++ b/backend/open_webui/test/apps/webui/routers/test_chats.py @@ -12,7 +12,7 @@ class TestChats(AbstractPostgresTest): def setup_method(self): super().setup_method() - from open_webui.apps.webui.models.chats import ChatForm, Chats + from open_webui.models.chats import ChatForm, Chats self.chats = Chats self.chats.insert_new_chat( @@ -88,7 +88,7 @@ class TestChats(AbstractPostgresTest): def test_get_user_archived_chats(self): self.chats.archive_all_chats_by_user_id("2") - from open_webui.apps.webui.internal.db import Session + from open_webui.internal.db import Session Session.commit() with mock_webui_user(id="2"): diff --git a/backend/open_webui/test/apps/webui/routers/test_models.py b/backend/open_webui/test/apps/webui/routers/test_models.py index 1d52658b8..c16ca9d07 100644 --- a/backend/open_webui/test/apps/webui/routers/test_models.py +++ b/backend/open_webui/test/apps/webui/routers/test_models.py @@ -7,7 +7,7 @@ class TestModels(AbstractPostgresTest): def setup_class(cls): super().setup_class() - from open_webui.apps.webui.models.models import Model + from open_webui.models.models import Model cls.models = Model diff --git a/backend/open_webui/test/apps/webui/routers/test_users.py b/backend/open_webui/test/apps/webui/routers/test_users.py index 6facf7055..1a58ab147 100644 --- a/backend/open_webui/test/apps/webui/routers/test_users.py +++ b/backend/open_webui/test/apps/webui/routers/test_users.py @@ -25,7 +25,7 @@ class TestUsers(AbstractPostgresTest): def setup_class(cls): super().setup_class() - from open_webui.apps.webui.models.users import Users + from open_webui.models.users import Users cls.users = Users diff --git a/backend/open_webui/test/util/abstract_integration_test.py b/backend/open_webui/test/util/abstract_integration_test.py index 2814731e0..e8492befb 100644 --- a/backend/open_webui/test/util/abstract_integration_test.py +++ b/backend/open_webui/test/util/abstract_integration_test.py @@ -115,7 +115,7 @@ class AbstractPostgresTest(AbstractIntegrationTest): pytest.fail(f"Could not setup test environment: {ex}") def _check_db_connection(self): - from open_webui.apps.webui.internal.db import Session + from open_webui.internal.db import Session retries = 10 while retries > 0: @@ -139,7 +139,7 @@ class AbstractPostgresTest(AbstractIntegrationTest): cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True) def teardown_method(self): - from open_webui.apps.webui.internal.db import Session + from open_webui.internal.db import Session # rollback everything not yet committed Session.commit() diff --git a/backend/open_webui/test/util/mock_user.py b/backend/open_webui/test/util/mock_user.py index ba8e24d4e..e25256460 100644 --- a/backend/open_webui/test/util/mock_user.py +++ b/backend/open_webui/test/util/mock_user.py @@ -5,7 +5,7 @@ from fastapi import FastAPI @contextmanager def mock_webui_user(**kwargs): - from open_webui.apps.webui.main import app + from backend.open_webui.routers.webui import app with mock_user(app, **kwargs): yield @@ -19,7 +19,7 @@ def mock_user(app: FastAPI, **kwargs): get_admin_user, get_current_user_by_api_key, ) - from open_webui.apps.webui.models.users import User + from open_webui.models.users import User def create_user(): user_parameters = { diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py index 270b28bcc..3b3e75a8b 100644 --- a/backend/open_webui/utils/access_control.py +++ b/backend/open_webui/utils/access_control.py @@ -1,5 +1,5 @@ from typing import Optional, Union, List, Dict, Any -from open_webui.apps.webui.models.groups import Groups +from open_webui.models.groups import Groups import json diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index cde953102..e1a0ca671 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -5,7 +5,7 @@ import jwt from datetime import UTC, datetime, timedelta from typing import Optional, Union, List, Dict -from open_webui.apps.webui.models.users import Users +from open_webui.models.users import Users from open_webui.constants import ERROR_MESSAGES from open_webui.env import WEBUI_SECRET_KEY diff --git a/backend/open_webui/apps/images/utils/comfyui.py b/backend/open_webui/utils/images/comfyui.py similarity index 100% rename from backend/open_webui/apps/images/utils/comfyui.py rename to backend/open_webui/utils/images/comfyui.py diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 37dc5b788..f0ab7a345 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -12,8 +12,8 @@ from fastapi import ( ) from starlette.responses import RedirectResponse -from open_webui.apps.webui.models.auths import Auths -from open_webui.apps.webui.models.users import Users +from open_webui.models.auths import Auths +from open_webui.models.users import Users from open_webui.config import ( DEFAULT_USER_ROLE, ENABLE_OAUTH_SIGNUP, @@ -158,8 +158,13 @@ class OAuthManager: if not email: log.warning(f"OAuth callback failed, email is missing: {user_data}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - if "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS: - log.warning(f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}") + if ( + "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS + and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS + ): + log.warning( + f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}" + ) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) # Check if the user exists diff --git a/backend/open_webui/utils/pdf_generator.py b/backend/open_webui/utils/pdf_generator.py index 6b7d506e6..bbaf42dbb 100644 --- a/backend/open_webui/utils/pdf_generator.py +++ b/backend/open_webui/utils/pdf_generator.py @@ -9,7 +9,7 @@ import site from fpdf import FPDF from open_webui.env import STATIC_DIR, FONTS_DIR -from open_webui.apps.webui.models.chats import ChatTitleMessagesForm +from open_webui.models.chats import ChatTitleMessagesForm class PDFGenerator: diff --git a/backend/open_webui/apps/webui/utils.py b/backend/open_webui/utils/plugin.py similarity index 98% rename from backend/open_webui/apps/webui/utils.py rename to backend/open_webui/utils/plugin.py index 054158b3e..17b86cea1 100644 --- a/backend/open_webui/apps/webui/utils.py +++ b/backend/open_webui/utils/plugin.py @@ -8,8 +8,8 @@ import tempfile import logging from open_webui.env import SRC_LOG_LEVELS -from open_webui.apps.webui.models.functions import Functions -from open_webui.apps.webui.models.tools import Tools +from open_webui.models.functions import Functions +from open_webui.models.tools import Tools log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 60a9f942f..a88e71f20 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -5,9 +5,9 @@ from typing import Any, Awaitable, Callable, get_type_hints from functools import update_wrapper, partial from langchain_core.utils.function_calling import convert_to_openai_function -from open_webui.apps.webui.models.tools import Tools -from open_webui.apps.webui.models.users import UserModel -from open_webui.apps.webui.utils import load_tools_module_by_id +from open_webui.models.tools import Tools +from open_webui.models.users import UserModel +from backend.open_webui.utils.plugin import load_tools_module_by_id from pydantic import BaseModel, Field, create_model log = logging.getLogger(__name__) From 481919965015caaf95fd3a678d8b1e093980a76a Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 02:41:25 -0800 Subject: [PATCH 03/26] wip --- backend/open_webui/main.py | 181 +++++++++++++++-- backend/open_webui/routers/ollama.py | 292 ++++++++++++++++----------- 2 files changed, 334 insertions(+), 139 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index ab43ef8b4..308489ee6 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -46,6 +46,21 @@ from open_webui.routers import ( retrieval, pipelines, tasks, + auths, + chats, + folders, + configs, + groups, + files, + functions, + memories, + models, + knowledge, + prompts, + evaluations, + tools, + users, + utils, ) from open_webui.retrieval.utils import get_sources_from_files @@ -117,6 +132,60 @@ from open_webui.config import ( WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE, WHISPER_MODEL_DIR, + # Retrieval + RAG_TEMPLATE, + DEFAULT_RAG_TEMPLATE, + RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, + RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + RAG_RERANKING_MODEL, + RAG_RERANKING_MODEL_AUTO_UPDATE, + RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + RAG_EMBEDDING_ENGINE, + RAG_EMBEDDING_BATCH_SIZE, + RAG_RELEVANCE_THRESHOLD, + RAG_FILE_MAX_COUNT, + RAG_FILE_MAX_SIZE, + RAG_OPENAI_API_BASE_URL, + RAG_OPENAI_API_KEY, + RAG_OLLAMA_BASE_URL, + RAG_OLLAMA_API_KEY, + CHUNK_OVERLAP, + CHUNK_SIZE, + CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL, + RAG_TOP_K, + RAG_TEXT_SPLITTER, + TIKTOKEN_ENCODING_NAME, + PDF_EXTRACT_IMAGES, + YOUTUBE_LOADER_LANGUAGE, + YOUTUBE_LOADER_PROXY_URL, + # Retrieval (Web Search) + RAG_WEB_SEARCH_ENGINE, + RAG_WEB_SEARCH_RESULT_COUNT, + RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + JINA_API_KEY, + SEARCHAPI_API_KEY, + SEARCHAPI_ENGINE, + SEARXNG_QUERY_URL, + SERPER_API_KEY, + SERPLY_API_KEY, + SERPSTACK_API_KEY, + SERPSTACK_HTTPS, + TAVILY_API_KEY, + BING_SEARCH_V7_ENDPOINT, + BING_SEARCH_V7_SUBSCRIPTION_KEY, + BRAVE_SEARCH_API_KEY, + KAGI_SEARCH_API_KEY, + MOJEEK_SEARCH_API_KEY, + GOOGLE_PSE_API_KEY, + GOOGLE_PSE_ENGINE_ID, + ENABLE_RAG_HYBRID_SEARCH, + ENABLE_RAG_LOCAL_WEB_FETCH, + ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + ENABLE_RAG_WEB_SEARCH, + UPLOAD_DIR, # WebUI WEBUI_AUTH, WEBUI_NAME, @@ -383,6 +452,72 @@ app.state.FUNCTIONS = {} # ######################################## + +app.state.config.TOP_K = RAG_TOP_K +app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD +app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE +app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT + +app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH +app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( + ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION +) + +app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE +app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL + +app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER +app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME + +app.state.config.CHUNK_SIZE = CHUNK_SIZE +app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP + +app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE +app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL +app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE +app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL +app.state.config.RAG_TEMPLATE = RAG_TEMPLATE + +app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL +app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY + +app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL +app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY + +app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES + +app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE +app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL + + +app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH +app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE +app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST + +app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL +app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY +app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID +app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY +app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY +app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY +app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY +app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS +app.state.config.SERPER_API_KEY = SERPER_API_KEY +app.state.config.SERPLY_API_KEY = SERPLY_API_KEY +app.state.config.TAVILY_API_KEY = TAVILY_API_KEY +app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY +app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE +app.state.config.JINA_API_KEY = JINA_API_KEY +app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT +app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY + +app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT +app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS + + +app.state.YOUTUBE_LOADER_TRANSLATION = None +app.state.EMBEDDING_FUNCTION = None + ######################################## # # IMAGES @@ -1083,8 +1218,8 @@ def filter_pipeline(payload, user, models): try: urlIdx = filter["urlIdx"] - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + url = app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = app.state.config.OPENAI_API_KEYS[urlIdx] if key == "": continue @@ -1230,14 +1365,6 @@ async def check_url(request: Request, call_next): return response -# @app.middleware("http") -# async def update_embedding_function(request: Request, call_next): -# response = await call_next(request) -# if "/embedding/update" in request.url.path: -# webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION -# return response - - @app.middleware("http") async def inspect_websocket(request: Request, call_next): if ( @@ -1268,18 +1395,36 @@ app.add_middleware( app.mount("/ws", socket_app) -app.mount("/ollama", ollama_app) -app.mount("/openai", openai_app) - -app.mount("/images/api/v1", images_app) -app.mount("/audio/api/v1", audio_app) +app.include_router(ollama.router, prefix="/ollama") +app.include_router(openai.router, prefix="/openai") -app.mount("/retrieval/api/v1", retrieval_app) +app.include_router(images.router, prefix="/api/v1/images") +app.include_router(audio.router, prefix="/api/v1/audio") +app.include_router(retrieval.router, prefix="/api/v1/retrieval") -app.mount("/api/v1", webui_app) -app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION +app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"]) + +app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"]) +app.include_router(users.router, prefix="/api/v1/users", tags=["users"]) + +app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"]) + +app.include_router(models.router, prefix="/api/v1/models", tags=["models"]) +app.include_router(knowledge.router, prefix="/api/v1/knowledge", tags=["knowledge"]) +app.include_router(prompts.router, prefix="/api/v1/prompts", tags=["prompts"]) +app.include_router(tools.router, prefix="/api/v1/tools", tags=["tools"]) + +app.include_router(memories.router, prefix="/api/v1/memories", tags=["memories"]) +app.include_router(folders.router, prefix="/api/v1/folders", tags=["folders"]) +app.include_router(groups.router, prefix="/api/v1/groups", tags=["groups"]) +app.include_router(files.router, prefix="/api/v1/files", tags=["files"]) +app.include_router(functions.router, prefix="/api/v1/functions", tags=["functions"]) +app.include_router( + evaluations.router, prefix="/api/v1/evaluations", tags=["evaluations"] +) +app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"]) async def get_all_base_models(): diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 581a881b7..8a43d5c52 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -13,7 +13,15 @@ from aiocache import cached import requests -from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile +from fastapi import ( + Depends, + FastAPI, + File, + HTTPException, + Request, + UploadFile, + APIRouter, +) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel, ConfigDict @@ -26,18 +34,15 @@ from open_webui.models.models import Models from open_webui.config import ( UPLOAD_DIR, ) - - from open_webui.env import ( + ENV, + SRC_LOG_LEVELS, AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, BYPASS_MODEL_ACCESS_CONTROL, ) - from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ENV, SRC_LOG_LEVELS - from open_webui.utils.misc import ( calculate_sha256, @@ -54,13 +59,15 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) +router = APIRouter() + # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, # least connections, or least response time for better resource utilization and performance optimization. -@app.head("/") -@app.get("/") +@router.head("/") +@router.get("/") async def get_status(): return {"status": True} @@ -70,7 +77,7 @@ class ConnectionVerificationForm(BaseModel): key: Optional[str] = None -@app.post("/verify") +@router.post("/verify") async def verify_connection( form_data: ConnectionVerificationForm, user=Depends(get_admin_user) ): @@ -110,12 +117,12 @@ async def verify_connection( raise HTTPException(status_code=500, detail=error_detail) -@app.get("/config") -async def get_config(user=Depends(get_admin_user)): +@router.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): return { - "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, - "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, - "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS, + "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API, + "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS, + "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS, } @@ -125,23 +132,25 @@ class OllamaConfigForm(BaseModel): OLLAMA_API_CONFIGS: dict -@app.post("/config/update") -async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)): - app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API - app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS +@router.post("/config/update") +async def update_config( + request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API + request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS - app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS + request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS # Remove any extra configs - config_urls = app.state.config.OLLAMA_API_CONFIGS.keys() - for url in list(app.state.config.OLLAMA_BASE_URLS): + config_urls = request.app.state.config.OLLAMA_API_CONFIGS.keys() + for url in list(request.app.state.config.OLLAMA_BASE_URLS): if url not in config_urls: - app.state.config.OLLAMA_API_CONFIGS.pop(url, None) + request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None) return { - "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, - "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, - "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS, + "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API, + "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS, + "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS, } @@ -158,6 +167,12 @@ async def aiohttp_get(url, key=None): return None +def get_api_key(url, configs): + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + return configs.get(base_url, {}).get("key", None) + + async def cleanup_response( response: Optional[aiohttp.ClientResponse], session: Optional[aiohttp.ClientSession], @@ -169,7 +184,11 @@ async def cleanup_response( async def post_streaming_url( - url: str, payload: Union[str, bytes], stream: bool = True, content_type=None + url: str, + payload: Union[str, bytes], + stream: bool = True, + key: Optional[str] = None, + content_type=None, ): r = None try: @@ -177,12 +196,6 @@ async def post_streaming_url( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - headers = {"Content-Type": "application/json"} if key: headers["Authorization"] = f"Bearer {key}" @@ -246,13 +259,13 @@ def merge_models_lists(model_lists): @cached(ttl=3) async def get_all_models(): log.info("get_all_models()") - if app.state.config.ENABLE_OLLAMA_API: + if request.app.state.config.ENABLE_OLLAMA_API: tasks = [] - for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS): - if url not in app.state.config.OLLAMA_API_CONFIGS: + for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS): + if url not in request.app.state.config.OLLAMA_API_CONFIGS: tasks.append(aiohttp_get(f"{url}/api/tags")) else: - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) enable = api_config.get("enable", True) key = api_config.get("key", None) @@ -265,8 +278,8 @@ async def get_all_models(): for idx, response in enumerate(responses): if response: - url = app.state.config.OLLAMA_BASE_URLS[idx] - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + url = request.app.state.config.OLLAMA_BASE_URLS[idx] + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) prefix_id = api_config.get("prefix_id", None) model_ids = api_config.get("model_ids", []) @@ -298,21 +311,21 @@ async def get_all_models(): return models -@app.get("/api/tags") -@app.get("/api/tags/{url_idx}") +@router.get("/api/tags") +@router.get("/api/tags/{url_idx}") async def get_ollama_tags( - url_idx: Optional[int] = None, user=Depends(get_verified_user) + request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user) ): models = [] if url_idx is None: models = await get_all_models() else: - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] parsed_url = urlparse(url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {} @@ -356,18 +369,20 @@ async def get_ollama_tags( return models -@app.get("/api/version") -@app.get("/api/version/{url_idx}") -async def get_ollama_versions(url_idx: Optional[int] = None): - if app.state.config.ENABLE_OLLAMA_API: +@router.get("/api/version") +@router.get("/api/version/{url_idx}") +async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): + if request.app.state.config.ENABLE_OLLAMA_API: if url_idx is None: # returns lowest version tasks = [ aiohttp_get( f"{url}/api/version", - app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None), + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( + "key", None + ), ) - for url in app.state.config.OLLAMA_BASE_URLS + for url in request.app.state.config.OLLAMA_BASE_URLS ] responses = await asyncio.gather(*tasks) responses = list(filter(lambda x: x is not None, responses)) @@ -387,7 +402,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, ) else: - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] r = None try: @@ -414,22 +429,24 @@ async def get_ollama_versions(url_idx: Optional[int] = None): return {"version": False} -@app.get("/api/ps") -async def get_ollama_loaded_models(user=Depends(get_verified_user)): +@router.get("/api/ps") +async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)): """ List models that are currently loaded into Ollama memory, and which node they are loaded on. """ - if app.state.config.ENABLE_OLLAMA_API: + if request.app.state.config.ENABLE_OLLAMA_API: tasks = [ aiohttp_get( f"{url}/api/ps", - app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None), + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( + "key", None + ), ) - for url in app.state.config.OLLAMA_BASE_URLS + for url in request.app.state.config.OLLAMA_BASE_URLS ] responses = await asyncio.gather(*tasks) - return dict(zip(app.state.config.OLLAMA_BASE_URLS, responses)) + return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses)) else: return {} @@ -438,18 +455,25 @@ class ModelNameForm(BaseModel): name: str -@app.post("/api/pull") -@app.post("/api/pull/{url_idx}") +@router.post("/api/pull") +@router.post("/api/pull/{url_idx}") async def pull_model( - form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) + request: Request, + form_data: ModelNameForm, + url_idx: int = 0, + user=Depends(get_admin_user), ): - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") # Admin should be able to pull models from any source payload = {**form_data.model_dump(exclude_none=True), "insecure": True} - return await post_streaming_url(f"{url}/api/pull", json.dumps(payload)) + return await post_streaming_url( + url=f"{url}/api/pull", + payload=json.dumps(payload), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + ) class PushModelForm(BaseModel): @@ -458,9 +482,10 @@ class PushModelForm(BaseModel): stream: Optional[bool] = None -@app.delete("/api/push") -@app.delete("/api/push/{url_idx}") +@router.delete("/api/push") +@router.delete("/api/push/{url_idx}") async def push_model( + request: Request, form_data: PushModelForm, url_idx: Optional[int] = None, user=Depends(get_admin_user), @@ -477,11 +502,13 @@ async def push_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.debug(f"url: {url}") return await post_streaming_url( - f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode() + url=f"{url}/api/push", + payload=form_data.model_dump_json(exclude_none=True).encode(), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -492,17 +519,22 @@ class CreateModelForm(BaseModel): path: Optional[str] = None -@app.post("/api/create") -@app.post("/api/create/{url_idx}") +@router.post("/api/create") +@router.post("/api/create/{url_idx}") async def create_model( - form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) + request: Request, + form_data: CreateModelForm, + url_idx: int = 0, + user=Depends(get_admin_user), ): log.debug(f"form_data: {form_data}") - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") return await post_streaming_url( - f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode() + url=f"{url}/api/create", + payload=form_data.model_dump_json(exclude_none=True).encode(), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -511,9 +543,10 @@ class CopyModelForm(BaseModel): destination: str -@app.post("/api/copy") -@app.post("/api/copy/{url_idx}") +@router.post("/api/copy") +@router.post("/api/copy/{url_idx}") async def copy_model( + request: Request, form_data: CopyModelForm, url_idx: Optional[int] = None, user=Depends(get_admin_user), @@ -530,13 +563,13 @@ async def copy_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") parsed_url = urlparse(url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -573,9 +606,10 @@ async def copy_model( ) -@app.delete("/api/delete") -@app.delete("/api/delete/{url_idx}") +@router.delete("/api/delete") +@router.delete("/api/delete/{url_idx}") async def delete_model( + request: Request, form_data: ModelNameForm, url_idx: Optional[int] = None, user=Depends(get_admin_user), @@ -592,13 +626,13 @@ async def delete_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") parsed_url = urlparse(url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -634,8 +668,10 @@ async def delete_model( ) -@app.post("/api/show") -async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)): +@router.post("/api/show") +async def show_model_info( + request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) +): model_list = await get_all_models() models = {model["model"]: model for model in model_list["models"]} @@ -646,13 +682,13 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us ) url_idx = random.choice(models[form_data.name]["urls"]) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") parsed_url = urlparse(url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -701,8 +737,8 @@ class GenerateEmbedForm(BaseModel): keep_alive: Optional[Union[int, str]] = None -@app.post("/api/embed") -@app.post("/api/embed/{url_idx}") +@router.post("/api/embed") +@router.post("/api/embed/{url_idx}") async def generate_embeddings( form_data: GenerateEmbedForm, url_idx: Optional[int] = None, @@ -711,8 +747,8 @@ async def generate_embeddings( return await generate_ollama_batch_embeddings(form_data, url_idx) -@app.post("/api/embeddings") -@app.post("/api/embeddings/{url_idx}") +@router.post("/api/embeddings") +@router.post("/api/embeddings/{url_idx}") async def generate_embeddings( form_data: GenerateEmbeddingsForm, url_idx: Optional[int] = None, @@ -744,13 +780,13 @@ async def generate_ollama_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") parsed_url = urlparse(url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -814,13 +850,13 @@ async def generate_ollama_batch_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") parsed_url = urlparse(url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -873,9 +909,10 @@ class GenerateCompletionForm(BaseModel): keep_alive: Optional[Union[int, str]] = None -@app.post("/api/generate") -@app.post("/api/generate/{url_idx}") +@router.post("/api/generate") +@router.post("/api/generate/{url_idx}") async def generate_completion( + request: Request, form_data: GenerateCompletionForm, url_idx: Optional[int] = None, user=Depends(get_verified_user), @@ -897,15 +934,17 @@ async def generate_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: form_data.model = form_data.model.replace(f"{prefix_id}.", "") log.info(f"url: {url}") return await post_streaming_url( - f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode() + url=f"{url}/api/generate", + payload=form_data.model_dump_json(exclude_none=True).encode(), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -936,13 +975,14 @@ async def get_ollama_url(url_idx: Optional[int], model: str): detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), ) url_idx = random.choice(models[model]["urls"]) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] return url -@app.post("/api/chat") -@app.post("/api/chat/{url_idx}") +@router.post("/api/chat") +@router.post("/api/chat/{url_idx}") async def generate_chat_completion( + request: Request, form_data: GenerateChatCompletionForm, url_idx: Optional[int] = None, user=Depends(get_verified_user), @@ -1003,15 +1043,16 @@ async def generate_chat_completion( parsed_url = urlparse(url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") return await post_streaming_url( - f"{url}/api/chat", - json.dumps(payload), + url=f"{url}/api/chat", + payload=json.dumps(payload), stream=form_data.stream, + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), content_type="application/x-ndjson", ) @@ -1043,10 +1084,13 @@ class OpenAICompletionForm(BaseModel): model_config = ConfigDict(extra="allow") -@app.post("/v1/completions") -@app.post("/v1/completions/{url_idx}") +@router.post("/v1/completions") +@router.post("/v1/completions/{url_idx}") async def generate_openai_completion( - form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user) + request: Request, + form_data: dict, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), ): try: form_data = OpenAICompletionForm(**form_data) @@ -1099,22 +1143,24 @@ async def generate_openai_completion( url = await get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") return await post_streaming_url( - f"{url}/v1/completions", - json.dumps(payload), + url=f"{url}/v1/completions", + payload=json.dumps(payload), stream=payload.get("stream", False), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) -@app.post("/v1/chat/completions") -@app.post("/v1/chat/completions/{url_idx}") +@router.post("/v1/chat/completions") +@router.post("/v1/chat/completions/{url_idx}") async def generate_openai_chat_completion( + request: Request, form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), @@ -1172,21 +1218,23 @@ async def generate_openai_chat_completion( url = await get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") return await post_streaming_url( - f"{url}/v1/chat/completions", - json.dumps(payload), + url=f"{url}/v1/chat/completions", + payload=json.dumps(payload), stream=payload.get("stream", False), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) -@app.get("/v1/models") -@app.get("/v1/models/{url_idx}") +@router.get("/v1/models") +@router.get("/v1/models/{url_idx}") async def get_openai_models( + request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user), ): @@ -1205,7 +1253,7 @@ async def get_openai_models( ] else: - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/tags") r.raise_for_status() @@ -1329,9 +1377,10 @@ async def download_file_stream( # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" -@app.post("/models/download") -@app.post("/models/download/{url_idx}") +@router.post("/models/download") +@router.post("/models/download/{url_idx}") async def download_model( + request: Request, form_data: UrlForm, url_idx: Optional[int] = None, user=Depends(get_admin_user), @@ -1346,7 +1395,7 @@ async def download_model( if url_idx is None: url_idx = 0 - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] file_name = parse_huggingface_url(form_data.url) @@ -1360,16 +1409,17 @@ async def download_model( return None -@app.post("/models/upload") -@app.post("/models/upload/{url_idx}") +@router.post("/models/upload") +@router.post("/models/upload/{url_idx}") def upload_model( + request: Request, file: UploadFile = File(...), url_idx: Optional[int] = None, user=Depends(get_admin_user), ): if url_idx is None: url_idx = 0 - ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] + ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] file_path = f"{UPLOAD_DIR}/{file.filename}" From df48eac22bbd5665d4e3a2d4f8ca635555449900 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 03:38:45 -0800 Subject: [PATCH 04/26] wip --- backend/open_webui/main.py | 12 +- backend/open_webui/routers/ollama.py | 378 ++++++++++++++------------- 2 files changed, 204 insertions(+), 186 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 308489ee6..2e1929bb3 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -372,6 +372,7 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS +app.state.OLLAMA_MODELS = {} ######################################## # @@ -384,6 +385,7 @@ app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS +app.state.OPENAI_MODELS = {} ######################################## # @@ -607,6 +609,14 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( ) +######################################## +# +# WEBUI +# +######################################## + +app.state.MODELS = {} + ################################## # # ChatCompletion Middleware @@ -1437,7 +1447,7 @@ async def get_all_base_models(): openai_models = openai_models["data"] if app.state.config.ENABLE_OLLAMA_API: - ollama_models = await get_ollama_models() + ollama_models = await ollama.get_all_models() ollama_models = [ { "id": model["model"], diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 8a43d5c52..19bc12e21 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -1,3 +1,7 @@ +# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. +# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, +# least connections, or least response time for better resource utilization and performance optimization. + import asyncio import json import logging @@ -29,6 +33,16 @@ from starlette.background import BackgroundTask from open_webui.models.models import Models +from open_webui.utils.misc import ( + calculate_sha256, +) +from open_webui.utils.payload import ( + apply_model_params_to_body_ollama, + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, +) +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access from open_webui.config import ( @@ -41,29 +55,114 @@ from open_webui.env import ( AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, BYPASS_MODEL_ACCESS_CONTROL, ) - from open_webui.constants import ERROR_MESSAGES -from open_webui.utils.misc import ( - calculate_sha256, -) -from open_webui.utils.payload import ( - apply_model_params_to_body_ollama, - apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, -) -from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access - log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) -router = APIRouter() +########################################## +# +# Utility functions +# +########################################## -# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. -# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, -# least connections, or least response time for better resource utilization and performance optimization. + +async def send_get_request(url, key=None): + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + try: + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: + async with session.get( + url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + ) as response: + return await response.json() + except Exception as e: + # Handle connection error here + log.error(f"Connection error: {e}") + return None + + +async def send_post_request( + url: str, + payload: Union[str, bytes], + stream: bool = True, + key: Optional[str] = None, + content_type: Optional[str] = None, +): + async def cleanup_response( + response: Optional[aiohttp.ClientResponse], + session: Optional[aiohttp.ClientSession], + ): + if response: + response.close() + if session: + await session.close() + + r = None + try: + session = aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) + + r = await session.post( + url, + data=payload, + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + ) + r.raise_for_status() + + if stream: + response_headers = dict(r.headers) + + if content_type: + response_headers["Content-Type"] = content_type + + return StreamingResponse( + r.content, + status_code=r.status, + headers=response_headers, + background=BackgroundTask( + cleanup_response, response=r, session=session + ), + ) + else: + res = await r.json() + await cleanup_response(r, session) + return res + + except Exception as e: + detail = None + + if r is not None: + try: + res = await r.json() + if "error" in res: + detail = f"Ollama: {res.get('error', 'Unknown error')}" + except Exception: + detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + +def get_api_key(url, configs): + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + return configs.get(base_url, {}).get("key", None) + + +########################################## +# +# API routes +# +########################################## + +router = APIRouter() @router.head("/") @@ -84,35 +183,31 @@ async def verify_connection( url = form_data.url key = form_data.key - headers = {} - if key: - headers["Authorization"] = f"Bearer {key}" - - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) - async with aiohttp.ClientSession(timeout=timeout) as session: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + ) as session: try: - async with session.get(f"{url}/api/version", headers=headers) as r: + async with session.get( + f"{url}/api/version", + headers={**({"Authorization": f"Bearer {key}"} if key else {})}, + ) as r: if r.status != 200: - # Extract response error details if available - error_detail = f"HTTP Error: {r.status}" + detail = f"HTTP Error: {r.status}" res = await r.json() + if "error" in res: - error_detail = f"External Error: {res['error']}" - raise Exception(error_detail) - - response_data = await r.json() - return response_data + detail = f"External Error: {res['error']}" + raise Exception(detail) + data = await r.json() + return data except aiohttp.ClientError as e: - # ClientError covers all aiohttp requests issues log.exception(f"Client error: {str(e)}") - # Handle aiohttp-specific connection issues, timeout etc. raise HTTPException( status_code=500, detail="Open WebUI: Server Connection Error" ) except Exception as e: log.exception(f"Unexpected error: {e}") - # Generic error handler in case parsing JSON or other steps fail error_detail = f"Unexpected error: {str(e)}" raise HTTPException(status_code=500, detail=error_detail) @@ -137,8 +232,8 @@ async def update_config( request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user) ): request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API - request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS + request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS # Remove any extra configs @@ -154,127 +249,26 @@ async def update_config( } -async def aiohttp_get(url, key=None): - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) - try: - headers = {"Authorization": f"Bearer {key}"} if key else {} - async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get(url, headers=headers) as response: - return await response.json() - except Exception as e: - # Handle connection error here - log.error(f"Connection error: {e}") - return None - - -def get_api_key(url, configs): - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - return configs.get(base_url, {}).get("key", None) - - -async def cleanup_response( - response: Optional[aiohttp.ClientResponse], - session: Optional[aiohttp.ClientSession], -): - if response: - response.close() - if session: - await session.close() - - -async def post_streaming_url( - url: str, - payload: Union[str, bytes], - stream: bool = True, - key: Optional[str] = None, - content_type=None, -): - r = None - try: - session = aiohttp.ClientSession( - trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) - ) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = await session.post( - url, - data=payload, - headers=headers, - ) - r.raise_for_status() - - if stream: - response_headers = dict(r.headers) - if content_type: - response_headers["Content-Type"] = content_type - return StreamingResponse( - r.content, - status_code=r.status, - headers=response_headers, - background=BackgroundTask( - cleanup_response, response=r, session=session - ), - ) - else: - res = await r.json() - await cleanup_response(r, session) - return res - - except Exception as e: - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = await r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status if r else 500, - detail=error_detail, - ) - - -def merge_models_lists(model_lists): - merged_models = {} - - for idx, model_list in enumerate(model_lists): - if model_list is not None: - for model in model_list: - id = model["model"] - if id not in merged_models: - model["urls"] = [idx] - merged_models[id] = model - else: - merged_models[id]["urls"].append(idx) - - return list(merged_models.values()) - - @cached(ttl=3) -async def get_all_models(): +async def get_all_models(request: Request): log.info("get_all_models()") if request.app.state.config.ENABLE_OLLAMA_API: - tasks = [] + request_tasks = [] + for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS): if url not in request.app.state.config.OLLAMA_API_CONFIGS: - tasks.append(aiohttp_get(f"{url}/api/tags")) + request_tasks.append(send_get_request(f"{url}/api/tags")) else: api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) enable = api_config.get("enable", True) key = api_config.get("key", None) if enable: - tasks.append(aiohttp_get(f"{url}/api/tags", key)) + request_tasks.append(send_get_request(f"{url}/api/tags", key)) else: - tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) + request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) - responses = await asyncio.gather(*tasks) + responses = await asyncio.gather(*request_tasks) for idx, response in enumerate(responses): if response: @@ -296,6 +290,21 @@ async def get_all_models(): for model in response.get("models", []): model["model"] = f"{prefix_id}.{model['model']}" + def merge_models_lists(model_lists): + merged_models = {} + + for idx, model_list in enumerate(model_lists): + if model_list is not None: + for model in model_list: + id = model["model"] + if id not in merged_models: + model["urls"] = [idx] + merged_models[id] = model + else: + merged_models[id]["urls"].append(idx) + + return list(merged_models.values()) + models = { "models": merge_models_lists( map( @@ -311,60 +320,61 @@ async def get_all_models(): return models +async def get_filtered_models(models, user): + # Filter models based on user access control + filtered_models = [] + for model in models.get("models", []): + model_info = Models.get_model_by_id(model["model"]) + if model_info: + if user.id == model_info.user_id or has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + return filtered_models + + @router.get("/api/tags") @router.get("/api/tags/{url_idx}") async def get_ollama_tags( request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user) ): models = [] + if url_idx is None: models = await get_all_models() else: url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {} - if key: - headers["Authorization"] = f"Bearer {key}" + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) r = None try: - r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers) + r = requests.request( + method="GET", + url=f"{url}/api/tags", + headers={**({"Authorization": f"Bearer {key}"} if key else {})}, + ) r.raise_for_status() models = r.json() except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status_code if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: - # Filter models based on user access control - filtered_models = [] - for model in models.get("models", []): - model_info = Models.get_model_by_id(model["model"]) - if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control - ): - filtered_models.append(model) - models["models"] = filtered_models + models["models"] = get_filtered_models(models, user) return models @@ -376,7 +386,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): if url_idx is None: # returns lowest version tasks = [ - aiohttp_get( + send_get_request( f"{url}/api/version", request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( "key", None @@ -412,18 +422,19 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): return r.json() except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status_code if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) else: return {"version": False} @@ -436,7 +447,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u """ if request.app.state.config.ENABLE_OLLAMA_API: tasks = [ - aiohttp_get( + send_get_request( f"{url}/api/ps", request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( "key", None @@ -469,7 +480,7 @@ async def pull_model( # Admin should be able to pull models from any source payload = {**form_data.model_dump(exclude_none=True), "insecure": True} - return await post_streaming_url( + return await send_post_request( url=f"{url}/api/pull", payload=json.dumps(payload), key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), @@ -505,7 +516,7 @@ async def push_model( url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.debug(f"url: {url}") - return await post_streaming_url( + return await send_post_request( url=f"{url}/api/push", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), @@ -531,7 +542,7 @@ async def create_model( url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - return await post_streaming_url( + return await send_post_request( url=f"{url}/api/create", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), @@ -941,7 +952,7 @@ async def generate_completion( form_data.model = form_data.model.replace(f"{prefix_id}.", "") log.info(f"url: {url}") - return await post_streaming_url( + return await send_post_request( url=f"{url}/api/generate", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), @@ -966,15 +977,13 @@ class GenerateChatCompletionForm(BaseModel): async def get_ollama_url(url_idx: Optional[int], model: str): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} - + models = request.app.state.OLLAMA_MODELS if model not in models: raise HTTPException( status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), ) - url_idx = random.choice(models[model]["urls"]) + url_idx = random.choice(models[model].get("urls", [])) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] return url @@ -1037,7 +1046,6 @@ async def generate_chat_completion( payload["model"] = f"{payload['model']}:latest" url = await get_ollama_url(url_idx, payload["model"]) - log.info(f"url: {url}") log.debug(f"generate_chat_completion() - 2.payload = {payload}") parsed_url = urlparse(url) @@ -1048,7 +1056,7 @@ async def generate_chat_completion( if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") - return await post_streaming_url( + return await send_post_request( url=f"{url}/api/chat", payload=json.dumps(payload), stream=form_data.stream, @@ -1149,7 +1157,7 @@ async def generate_openai_completion( if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") - return await post_streaming_url( + return await send_post_request( url=f"{url}/v1/completions", payload=json.dumps(payload), stream=payload.get("stream", False), @@ -1223,7 +1231,7 @@ async def generate_openai_chat_completion( if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") - return await post_streaming_url( + return await send_post_request( url=f"{url}/v1/chat/completions", payload=json.dumps(payload), stream=payload.get("stream", False), From df0cdd9f3ca064d5d8498538a32a675697d3cac9 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 04:37:47 -0800 Subject: [PATCH 05/26] wip --- backend/open_webui/routers/audio.py | 474 ++++++++++++++------------- backend/open_webui/routers/ollama.py | 377 +++++++++------------ 2 files changed, 404 insertions(+), 447 deletions(-) diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index 3203727a7..d410369af 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -11,25 +11,27 @@ from pydub.silence import split_on_silence import aiohttp import aiofiles import requests + +from fastapi import ( + Depends, + FastAPI, + File, + HTTPException, + Request, + UploadFile, + status, + APIRouter, +) +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from pydantic import BaseModel + + +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.config import ( - AUDIO_STT_ENGINE, - AUDIO_STT_MODEL, - AUDIO_STT_OPENAI_API_BASE_URL, - AUDIO_STT_OPENAI_API_KEY, - AUDIO_TTS_API_KEY, - AUDIO_TTS_ENGINE, - AUDIO_TTS_MODEL, - AUDIO_TTS_OPENAI_API_BASE_URL, - AUDIO_TTS_OPENAI_API_KEY, - AUDIO_TTS_SPLIT_ON, - AUDIO_TTS_VOICE, - AUDIO_TTS_AZURE_SPEECH_REGION, - AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, - WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE, WHISPER_MODEL_DIR, CACHE_DIR, - AppConfig, ) from open_webui.constants import ERROR_MESSAGES @@ -40,78 +42,25 @@ from open_webui.env import ( ENABLE_FORWARD_USER_INFO_HEADERS, ) -from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse -from pydantic import BaseModel -from open_webui.utils.auth import get_admin_user, get_verified_user + +router = APIRouter() # Constants MAX_FILE_SIZE_MB = 25 MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes - log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["AUDIO"]) - -# setting device type for whisper model -whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" -log.info(f"whisper_device_type: {whisper_device_type}") - SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) -def set_faster_whisper_model(model: str, auto_update: bool = False): - if model and app.state.config.STT_ENGINE == "": - from faster_whisper import WhisperModel - - faster_whisper_kwargs = { - "model_size_or_path": model, - "device": whisper_device_type, - "compute_type": "int8", - "download_root": WHISPER_MODEL_DIR, - "local_files_only": not auto_update, - } - - try: - app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs) - except Exception: - log.warning( - "WhisperModel initialization failed, attempting download with local_files_only=False" - ) - faster_whisper_kwargs["local_files_only"] = False - app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs) - - else: - app.state.faster_whisper_model = None - - -class TTSConfigForm(BaseModel): - OPENAI_API_BASE_URL: str - OPENAI_API_KEY: str - API_KEY: str - ENGINE: str - MODEL: str - VOICE: str - SPLIT_ON: str - AZURE_SPEECH_REGION: str - AZURE_SPEECH_OUTPUT_FORMAT: str - - -class STTConfigForm(BaseModel): - OPENAI_API_BASE_URL: str - OPENAI_API_KEY: str - ENGINE: str - MODEL: str - WHISPER_MODEL: str - - -class AudioConfigUpdateForm(BaseModel): - tts: TTSConfigForm - stt: STTConfigForm - +########################################## +# +# Utility functions +# +########################################## from pydub import AudioSegment from pydub.utils import mediainfo @@ -140,71 +89,124 @@ def convert_mp4_to_wav(file_path, output_path): print(f"Converted {file_path} to {output_path}") -@app.get("/config") -async def get_audio_config(user=Depends(get_admin_user)): +def set_faster_whisper_model(model: str, auto_update: bool = False): + whisper_model = None + if model: + from faster_whisper import WhisperModel + + faster_whisper_kwargs = { + "model_size_or_path": model, + "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu", + "compute_type": "int8", + "download_root": WHISPER_MODEL_DIR, + "local_files_only": not auto_update, + } + + try: + whisper_model = WhisperModel(**faster_whisper_kwargs) + except Exception: + log.warning( + "WhisperModel initialization failed, attempting download with local_files_only=False" + ) + faster_whisper_kwargs["local_files_only"] = False + whisper_model = WhisperModel(**faster_whisper_kwargs) + return whisper_model + + +class TTSConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + API_KEY: str + ENGINE: str + MODEL: str + VOICE: str + SPLIT_ON: str + AZURE_SPEECH_REGION: str + AZURE_SPEECH_OUTPUT_FORMAT: str + + +class STTConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + ENGINE: str + MODEL: str + WHISPER_MODEL: str + + +class AudioConfigUpdateForm(BaseModel): + tts: TTSConfigForm + stt: STTConfigForm + + +@router.get("/config") +async def get_audio_config(request: Request, user=Depends(get_admin_user)): return { "tts": { - "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, - "API_KEY": app.state.config.TTS_API_KEY, - "ENGINE": app.state.config.TTS_ENGINE, - "MODEL": app.state.config.TTS_MODEL, - "VOICE": app.state.config.TTS_VOICE, - "SPLIT_ON": app.state.config.TTS_SPLIT_ON, - "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION, - "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, + "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY, + "API_KEY": request.app.state.config.TTS_API_KEY, + "ENGINE": request.app.state.config.TTS_ENGINE, + "MODEL": request.app.state.config.TTS_MODEL, + "VOICE": request.app.state.config.TTS_VOICE, + "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON, + "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION, + "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, }, "stt": { - "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, - "ENGINE": app.state.config.STT_ENGINE, - "MODEL": app.state.config.STT_MODEL, - "WHISPER_MODEL": app.state.config.WHISPER_MODEL, + "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY, + "ENGINE": request.app.state.config.STT_ENGINE, + "MODEL": request.app.state.config.STT_MODEL, + "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, }, } -@app.post("/config/update") +@router.post("/config/update") async def update_audio_config( - form_data: AudioConfigUpdateForm, user=Depends(get_admin_user) + request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user) ): - app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL - app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY - app.state.config.TTS_API_KEY = form_data.tts.API_KEY - app.state.config.TTS_ENGINE = form_data.tts.ENGINE - app.state.config.TTS_MODEL = form_data.tts.MODEL - app.state.config.TTS_VOICE = form_data.tts.VOICE - app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON - app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION - app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = ( + request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL + request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY + request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY + request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE + request.app.state.config.TTS_MODEL = form_data.tts.MODEL + request.app.state.config.TTS_VOICE = form_data.tts.VOICE + request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON + request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION + request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = ( form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT ) - app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL - app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY - app.state.config.STT_ENGINE = form_data.stt.ENGINE - app.state.config.STT_MODEL = form_data.stt.MODEL - app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL - set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE) + request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL + request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY + request.app.state.config.STT_ENGINE = form_data.stt.ENGINE + request.app.state.config.STT_MODEL = form_data.stt.MODEL + request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL + + if request.app.state.config.STT_ENGINE == "": + request.app.state.faster_whisper_model = set_faster_whisper_model( + form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE + ) return { "tts": { - "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, - "API_KEY": app.state.config.TTS_API_KEY, - "ENGINE": app.state.config.TTS_ENGINE, - "MODEL": app.state.config.TTS_MODEL, - "VOICE": app.state.config.TTS_VOICE, - "SPLIT_ON": app.state.config.TTS_SPLIT_ON, - "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION, - "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, + "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY, + "API_KEY": request.app.state.config.TTS_API_KEY, + "ENGINE": request.app.state.config.TTS_ENGINE, + "MODEL": request.app.state.config.TTS_MODEL, + "VOICE": request.app.state.config.TTS_VOICE, + "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON, + "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION, + "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, }, "stt": { - "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, - "ENGINE": app.state.config.STT_ENGINE, - "MODEL": app.state.config.STT_MODEL, - "WHISPER_MODEL": app.state.config.WHISPER_MODEL, + "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY, + "ENGINE": request.app.state.config.STT_ENGINE, + "MODEL": request.app.state.config.STT_MODEL, + "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, }, } @@ -213,18 +215,18 @@ def load_speech_pipeline(): from transformers import pipeline from datasets import load_dataset - if app.state.speech_synthesiser is None: - app.state.speech_synthesiser = pipeline( + if request.app.state.speech_synthesiser is None: + request.app.state.speech_synthesiser = pipeline( "text-to-speech", "microsoft/speecht5_tts" ) - if app.state.speech_speaker_embeddings_dataset is None: - app.state.speech_speaker_embeddings_dataset = load_dataset( + if request.app.state.speech_speaker_embeddings_dataset is None: + request.app.state.speech_speaker_embeddings_dataset = load_dataset( "Matthijs/cmu-arctic-xvectors", split="validation" ) -@app.post("/speech") +@router.post("/speech") async def speech(request: Request, user=Depends(get_verified_user)): body = await request.body() name = hashlib.sha256(body).hexdigest() @@ -236,9 +238,11 @@ async def speech(request: Request, user=Depends(get_verified_user)): if file_path.is_file(): return FileResponse(file_path) - if app.state.config.TTS_ENGINE == "openai": + if request.app.state.config.TTS_ENGINE == "openai": headers = {} - headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}" + headers["Authorization"] = ( + f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}" + ) headers["Content-Type"] = "application/json" if ENABLE_FORWARD_USER_INFO_HEADERS: @@ -250,7 +254,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): try: body = body.decode("utf-8") body = json.loads(body) - body["model"] = app.state.config.TTS_MODEL + body["model"] = request.app.state.config.TTS_MODEL body = json.dumps(body).encode("utf-8") except Exception: pass @@ -258,7 +262,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): try: async with aiohttp.ClientSession() as session: async with session.post( - url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", + url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", data=body, headers=headers, ) as r: @@ -287,7 +291,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): detail=error_detail, ) - elif app.state.config.TTS_ENGINE == "elevenlabs": + elif request.app.state.config.TTS_ENGINE == "elevenlabs": try: payload = json.loads(body.decode("utf-8")) except Exception as e: @@ -305,11 +309,11 @@ async def speech(request: Request, user=Depends(get_verified_user)): headers = { "Accept": "audio/mpeg", "Content-Type": "application/json", - "xi-api-key": app.state.config.TTS_API_KEY, + "xi-api-key": request.app.state.config.TTS_API_KEY, } data = { "text": payload["input"], - "model_id": app.state.config.TTS_MODEL, + "model_id": request.app.state.config.TTS_MODEL, "voice_settings": {"stability": 0.5, "similarity_boost": 0.5}, } @@ -341,21 +345,21 @@ async def speech(request: Request, user=Depends(get_verified_user)): detail=error_detail, ) - elif app.state.config.TTS_ENGINE == "azure": + elif request.app.state.config.TTS_ENGINE == "azure": try: payload = json.loads(body.decode("utf-8")) except Exception as e: log.exception(e) raise HTTPException(status_code=400, detail="Invalid JSON payload") - region = app.state.config.TTS_AZURE_SPEECH_REGION - language = app.state.config.TTS_VOICE - locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1]) - output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT + region = request.app.state.config.TTS_AZURE_SPEECH_REGION + language = request.app.state.config.TTS_VOICE + locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1]) + output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1" headers = { - "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY, + "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY, "Content-Type": "application/ssml+xml", "X-Microsoft-OutputFormat": output_format, } @@ -378,7 +382,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): except Exception as e: log.exception(e) raise HTTPException(status_code=500, detail=str(e)) - elif app.state.config.TTS_ENGINE == "transformers": + elif request.app.state.config.TTS_ENGINE == "transformers": payload = None try: payload = json.loads(body.decode("utf-8")) @@ -391,12 +395,12 @@ async def speech(request: Request, user=Depends(get_verified_user)): load_speech_pipeline() - embeddings_dataset = app.state.speech_speaker_embeddings_dataset + embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset speaker_index = 6799 try: speaker_index = embeddings_dataset["filename"].index( - app.state.config.TTS_MODEL + request.app.state.config.TTS_MODEL ) except Exception: pass @@ -405,7 +409,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): embeddings_dataset[speaker_index]["xvector"] ).unsqueeze(0) - speech = app.state.speech_synthesiser( + speech = request.app.state.speech_synthesiser( payload["input"], forward_params={"speaker_embeddings": speaker_embedding}, ) @@ -417,17 +421,19 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) -def transcribe(file_path): +def transcribe(request: Request, file_path): print("transcribe", file_path) filename = os.path.basename(file_path) file_dir = os.path.dirname(file_path) id = filename.split(".")[0] - if app.state.config.STT_ENGINE == "": - if app.state.faster_whisper_model is None: - set_faster_whisper_model(app.state.config.WHISPER_MODEL) + if request.app.state.config.STT_ENGINE == "": + if request.app.state.faster_whisper_model is None: + request.app.state.faster_whisper_model = set_faster_whisper_model( + request.app.state.config.WHISPER_MODEL + ) - model = app.state.faster_whisper_model + model = request.app.state.faster_whisper_model segments, info = model.transcribe(file_path, beam_size=5) log.info( "Detected language '%s' with probability %f" @@ -444,31 +450,24 @@ def transcribe(file_path): log.debug(data) return data - elif app.state.config.STT_ENGINE == "openai": + elif request.app.state.config.STT_ENGINE == "openai": if is_mp4_audio(file_path): - print("is_mp4_audio") os.rename(file_path, file_path.replace(".wav", ".mp4")) # Convert MP4 audio file to WAV format convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path) - headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"} - - files = {"file": (filename, open(file_path, "rb"))} - data = {"model": app.state.config.STT_MODEL} - - log.debug(files, data) - r = None try: r = requests.post( - url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", - headers=headers, - files=files, - data=data, + url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", + headers={ + "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}" + }, + files={"file": (filename, open(file_path, "rb"))}, + data={"model": request.app.state.config.STT_MODEL}, ) r.raise_for_status() - data = r.json() # save the transcript to a json file @@ -476,24 +475,43 @@ def transcribe(file_path): with open(transcript_file, "w") as f: json.dump(data, f) - print(data) return data except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"External: {res['error']['message']}" + detail = f"External: {res['error'].get('message', '')}" except Exception: - error_detail = f"External: {e}" + detail = f"External: {e}" - raise Exception(error_detail) + raise Exception(detail if detail else "Open WebUI: Server Connection Error") -@app.post("/transcriptions") +def compress_audio(file_path): + if os.path.getsize(file_path) > MAX_FILE_SIZE: + file_dir = os.path.dirname(file_path) + audio = AudioSegment.from_file(file_path) + audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio + compressed_path = f"{file_dir}/{id}_compressed.opus" + audio.export(compressed_path, format="opus", bitrate="32k") + log.debug(f"Compressed audio to {compressed_path}") + + if ( + os.path.getsize(compressed_path) > MAX_FILE_SIZE + ): # Still larger than MAX_FILE_SIZE after compression + raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB")) + return compressed_path + else: + return file_path + + +@router.post("/transcriptions") def transcription( + request: Request, file: UploadFile = File(...), user=Depends(get_verified_user), ): @@ -520,36 +538,22 @@ def transcription( f.write(contents) try: - if os.path.getsize(file_path) > MAX_FILE_SIZE: # file is bigger than 25MB - log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB") - audio = AudioSegment.from_file(file_path) - audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio - compressed_path = f"{file_dir}/{id}_compressed.opus" - audio.export(compressed_path, format="opus", bitrate="32k") - log.debug(f"Compressed audio to {compressed_path}") - file_path = compressed_path + try: + file_path = compress_audio(file_path) + except Exception as e: + log.exception(e) - if ( - os.path.getsize(file_path) > MAX_FILE_SIZE - ): # Still larger than 25MB after compression - log.debug( - f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}" - ) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.FILE_TOO_LARGE( - size=f"{MAX_FILE_SIZE_MB}MB" - ), - ) - - data = transcribe(file_path) - else: - data = transcribe(file_path) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + data = transcribe(request, file_path) file_path = file_path.split("/")[-1] return {**data, "filename": file_path} except Exception as e: log.exception(e) + raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -564,39 +568,41 @@ def transcription( ) -def get_available_models() -> list[dict]: - if app.state.config.TTS_ENGINE == "openai": - return [{"id": "tts-1"}, {"id": "tts-1-hd"}] - elif app.state.config.TTS_ENGINE == "elevenlabs": - headers = { - "xi-api-key": app.state.config.TTS_API_KEY, - "Content-Type": "application/json", - } - +def get_available_models(request: Request) -> list[dict]: + available_models = [] + if request.app.state.config.TTS_ENGINE == "openai": + available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}] + elif request.app.state.config.TTS_ENGINE == "elevenlabs": try: response = requests.get( - "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5 + "https://api.elevenlabs.io/v1/models", + headers={ + "xi-api-key": request.app.state.config.TTS_API_KEY, + "Content-Type": "application/json", + }, + timeout=5, ) response.raise_for_status() models = response.json() - return [ + + available_models = [ {"name": model["name"], "id": model["model_id"]} for model in models ] except requests.RequestException as e: log.error(f"Error fetching voices: {str(e)}") - return [] + return available_models -@app.get("/models") -async def get_models(user=Depends(get_verified_user)): - return {"models": get_available_models()} +@router.get("/models") +async def get_models(request: Request, user=Depends(get_verified_user)): + return {"models": get_available_models(request)} -def get_available_voices() -> dict: +def get_available_voices(request) -> dict: """Returns {voice_id: voice_name} dict""" - ret = {} - if app.state.config.TTS_ENGINE == "openai": - ret = { + available_voices = {} + if request.app.state.config.TTS_ENGINE == "openai": + available_voices = { "alloy": "alloy", "echo": "echo", "fable": "fable", @@ -604,33 +610,38 @@ def get_available_voices() -> dict: "nova": "nova", "shimmer": "shimmer", } - elif app.state.config.TTS_ENGINE == "elevenlabs": + elif request.app.state.config.TTS_ENGINE == "elevenlabs": try: - ret = get_elevenlabs_voices() + available_voices = get_elevenlabs_voices( + api_key=request.app.state.config.TTS_API_KEY + ) except Exception: # Avoided @lru_cache with exception pass - elif app.state.config.TTS_ENGINE == "azure": + elif request.app.state.config.TTS_ENGINE == "azure": try: - region = app.state.config.TTS_AZURE_SPEECH_REGION + region = request.app.state.config.TTS_AZURE_SPEECH_REGION url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list" - headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY} + headers = { + "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY + } response = requests.get(url, headers=headers) response.raise_for_status() voices = response.json() + for voice in voices: - ret[voice["ShortName"]] = ( + available_voices[voice["ShortName"]] = ( f"{voice['DisplayName']} ({voice['ShortName']})" ) except requests.RequestException as e: log.error(f"Error fetching voices: {str(e)}") - return ret + return available_voices @lru_cache -def get_elevenlabs_voices() -> dict: +def get_elevenlabs_voices(api_key: str) -> dict: """ Note, set the following in your .env file to use Elevenlabs: AUDIO_TTS_ENGINE=elevenlabs @@ -638,13 +649,16 @@ def get_elevenlabs_voices() -> dict: AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices AUDIO_TTS_MODEL=eleven_multilingual_v2 """ - headers = { - "xi-api-key": app.state.config.TTS_API_KEY, - "Content-Type": "application/json", - } + try: # TODO: Add retries - response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers) + response = requests.get( + "https://api.elevenlabs.io/v1/voices", + headers={ + "xi-api-key": api_key, + "Content-Type": "application/json", + }, + ) response.raise_for_status() voices_data = response.json() @@ -659,6 +673,10 @@ def get_elevenlabs_voices() -> dict: return voices -@app.get("/voices") -async def get_voices(user=Depends(get_verified_user)): - return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]} +@router.get("/voices") +async def get_voices(request: Request, user=Depends(get_verified_user)): + return { + "voices": [ + {"id": k, "name": v} for k, v in get_available_voices(request).items() + ] + } diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 19bc12e21..082d14ec3 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -385,7 +385,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): if request.app.state.config.ENABLE_OLLAMA_API: if url_idx is None: # returns lowest version - tasks = [ + request_tasks = [ send_get_request( f"{url}/api/version", request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( @@ -394,7 +394,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): ) for url in request.app.state.config.OLLAMA_BASE_URLS ] - responses = await asyncio.gather(*tasks) + responses = await asyncio.gather(*request_tasks) responses = list(filter(lambda x: x is not None, responses)) if len(responses) > 0: @@ -446,7 +446,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u List models that are currently loaded into Ollama memory, and which node they are loaded on. """ if request.app.state.config.ENABLE_OLLAMA_API: - tasks = [ + request_tasks = [ send_get_request( f"{url}/api/ps", request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( @@ -455,7 +455,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u ) for url in request.app.state.config.OLLAMA_BASE_URLS ] - responses = await asyncio.gather(*tasks) + responses = await asyncio.gather(*request_tasks) return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses)) else: @@ -502,8 +502,8 @@ async def push_model( user=Depends(get_admin_user), ): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS if form_data.name in models: url_idx = models[form_data.name]["urls"][0] @@ -540,7 +540,6 @@ async def create_model( ): log.debug(f"form_data: {form_data}") url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") return await send_post_request( url=f"{url}/api/create", @@ -563,8 +562,8 @@ async def copy_model( user=Depends(get_admin_user), ): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models() + models = request.app.state.OLLAMA_MODELS if form_data.source in models: url_idx = models[form_data.source]["urls"][0] @@ -575,45 +574,37 @@ async def copy_model( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="POST", - url=f"{url}/api/copy", - headers=headers, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) try: + r = requests.request( + method="POST", + url=f"{url}/api/copy", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) r.raise_for_status() log.debug(f"r.text: {r.text}") - return True except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status_code if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) @@ -626,8 +617,8 @@ async def delete_model( user=Depends(get_admin_user), ): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models() + models = request.app.state.OLLAMA_MODELS if form_data.name in models: url_idx = models[form_data.name]["urls"][0] @@ -638,44 +629,37 @@ async def delete_model( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="DELETE", - url=f"{url}/api/delete", - data=form_data.model_dump_json(exclude_none=True).encode(), - headers=headers, - ) try: + r = requests.request( + method="DELETE", + url=f"{url}/api/delete", + data=form_data.model_dump_json(exclude_none=True).encode(), + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + ) r.raise_for_status() log.debug(f"r.text: {r.text}") - return True except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status_code if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) @@ -683,8 +667,8 @@ async def delete_model( async def show_model_info( request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) ): - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models() + models = request.app.state.OLLAMA_MODELS if form_data.name not in models: raise HTTPException( @@ -693,53 +677,41 @@ async def show_model_info( ) url_idx = random.choice(models[form_data.name]["urls"]) + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="POST", - url=f"{url}/api/show", - headers=headers, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) try: + r = requests.request( + method="POST", + url=f"{url}/api/show", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) r.raise_for_status() return r.json() except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status_code if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) -class GenerateEmbeddingsForm(BaseModel): - model: str - prompt: str - options: Optional[dict] = None - keep_alive: Optional[Union[int, str]] = None - - class GenerateEmbedForm(BaseModel): model: str input: list[str] | str @@ -750,103 +722,17 @@ class GenerateEmbedForm(BaseModel): @router.post("/api/embed") @router.post("/api/embed/{url_idx}") -async def generate_embeddings( +async def embed( + request: Request, form_data: GenerateEmbedForm, url_idx: Optional[int] = None, user=Depends(get_verified_user), -): - return await generate_ollama_batch_embeddings(form_data, url_idx) - - -@router.post("/api/embeddings") -@router.post("/api/embeddings/{url_idx}") -async def generate_embeddings( - form_data: GenerateEmbeddingsForm, - url_idx: Optional[int] = None, - user=Depends(get_verified_user), -): - return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx) - - -async def generate_ollama_embeddings( - form_data: GenerateEmbeddingsForm, - url_idx: Optional[int] = None, -): - log.info(f"generate_ollama_embeddings {form_data}") - - if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} - - model = form_data.model - - if ":" not in model: - model = f"{model}:latest" - - if model in models: - url_idx = random.choice(models[model]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="POST", - url=f"{url}/api/embeddings", - headers=headers, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) - try: - r.raise_for_status() - - data = r.json() - - log.info(f"generate_ollama_embeddings {data}") - - if "embedding" in data: - return data - else: - raise Exception("Something went wrong :/") - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) - - -async def generate_ollama_batch_embeddings( - form_data: GenerateEmbedForm, - url_idx: Optional[int] = None, ): log.info(f"generate_ollama_batch_embeddings {form_data}") if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models() + models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -862,47 +748,107 @@ async def generate_ollama_batch_embeddings( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="POST", - url=f"{url}/api/embed", - headers=headers, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) try: + r = requests.request( + method="POST", + url=f"{url}/api/embed", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) r.raise_for_status() data = r.json() - - log.info(f"generate_ollama_batch_embeddings {data}") - - if "embeddings" in data: - return data - else: - raise Exception("Something went wrong :/") + return data except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" - raise Exception(error_detail) + raise HTTPException( + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + +class GenerateEmbeddingsForm(BaseModel): + model: str + prompt: str + options: Optional[dict] = None + keep_alive: Optional[Union[int, str]] = None + + +@router.post("/api/embeddings") +@router.post("/api/embeddings/{url_idx}") +async def embeddings( + request: Request, + form_data: GenerateEmbeddingsForm, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), +): + log.info(f"generate_ollama_embeddings {form_data}") + + if url_idx is None: + await get_all_models() + models = request.app.state.OLLAMA_MODELS + + model = form_data.model + + if ":" not in model: + model = f"{model}:latest" + + if model in models: + url_idx = random.choice(models[model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + + try: + r = requests.request( + method="POST", + url=f"{url}/api/embeddings", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + r.raise_for_status() + + data = r.json() + return data + except Exception as e: + log.exception(e) + + detail = None + if r is not None: + try: + res = r.json() + if "error" in res: + detail = f"Ollama: {res['error']}" + except Exception: + detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) class GenerateCompletionForm(BaseModel): @@ -947,10 +893,10 @@ async def generate_completion( url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) if prefix_id: form_data.model = form_data.model.replace(f"{prefix_id}.", "") - log.info(f"url: {url}") return await send_post_request( url=f"{url}/api/generate", @@ -975,7 +921,7 @@ class GenerateChatCompletionForm(BaseModel): keep_alive: Optional[Union[int, str]] = None -async def get_ollama_url(url_idx: Optional[int], model: str): +async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None): if url_idx is None: models = request.app.state.OLLAMA_MODELS if model not in models: @@ -1001,7 +947,6 @@ async def generate_chat_completion( bypass_filter = True payload = {**form_data.model_dump(exclude_none=True)} - log.debug(f"generate_chat_completion() - 1.payload = {payload}") if "metadata" in payload: del payload["metadata"] @@ -1045,13 +990,9 @@ async def generate_chat_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = await get_ollama_url(url_idx, payload["model"]) - log.debug(f"generate_chat_completion() - 2.payload = {payload}") + url = await get_ollama_url(request, payload["model"], url_idx) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") @@ -1148,10 +1089,9 @@ async def generate_openai_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = await get_ollama_url(url_idx, payload["model"]) - log.info(f"url: {url}") - + url = await get_ollama_url(request, payload["model"], url_idx) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) if prefix_id: @@ -1223,10 +1163,9 @@ async def generate_openai_chat_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = await get_ollama_url(url_idx, payload["model"]) - log.info(f"url: {url}") - + url = await get_ollama_url(request, payload["model"], url_idx) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") From 87d695caad8ff1fc70b50d62cf689837842318f5 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 04:47:35 -0800 Subject: [PATCH 06/26] Update audio.py --- backend/open_webui/routers/audio.py | 159 ++++++++++++++++------------ 1 file changed, 90 insertions(+), 69 deletions(-) diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index d410369af..a26355945 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -113,6 +113,13 @@ def set_faster_whisper_model(model: str, auto_update: bool = False): return whisper_model +########################################## +# +# Audio API +# +########################################## + + class TTSConfigForm(BaseModel): OPENAI_API_BASE_URL: str OPENAI_API_KEY: str @@ -238,35 +245,38 @@ async def speech(request: Request, user=Depends(get_verified_user)): if file_path.is_file(): return FileResponse(file_path) + payload = None + try: + payload = json.loads(body.decode("utf-8")) + except Exception as e: + log.exception(e) + raise HTTPException(status_code=400, detail="Invalid JSON payload") + if request.app.state.config.TTS_ENGINE == "openai": - headers = {} - headers["Authorization"] = ( - f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}" - ) - headers["Content-Type"] = "application/json" - - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role - - try: - body = body.decode("utf-8") - body = json.loads(body) - body["model"] = request.app.state.config.TTS_MODEL - body = json.dumps(body).encode("utf-8") - except Exception: - pass + payload["model"] = request.app.state.config.TTS_MODEL try: async with aiohttp.ClientSession() as session: async with session.post( url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", - data=body, - headers=headers, + data=payload, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, ) as r: r.raise_for_status() + async with aiofiles.open(file_path, "wb") as f: await f.write(await r.read()) @@ -277,50 +287,47 @@ async def speech(request: Request, user=Depends(get_verified_user)): except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + detail = None + try: if r.status != 200: res = await r.json() if "error" in res: - error_detail = f"External: {res['error']['message']}" + detail = f"External: {res['error'].get('message', '')}" except Exception: - error_detail = f"External: {e}" + detail = f"External: {e}" raise HTTPException( status_code=getattr(r, "status", 500), - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) elif request.app.state.config.TTS_ENGINE == "elevenlabs": - try: - payload = json.loads(body.decode("utf-8")) - except Exception as e: - log.exception(e) - raise HTTPException(status_code=400, detail="Invalid JSON payload") - voice_id = payload.get("voice", "") + if voice_id not in get_available_voices(): raise HTTPException( status_code=400, detail="Invalid voice id", ) - url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}" - headers = { - "Accept": "audio/mpeg", - "Content-Type": "application/json", - "xi-api-key": request.app.state.config.TTS_API_KEY, - } - data = { - "text": payload["input"], - "model_id": request.app.state.config.TTS_MODEL, - "voice_settings": {"stability": 0.5, "similarity_boost": 0.5}, - } - try: async with aiohttp.ClientSession() as session: - async with session.post(url, json=data, headers=headers) as r: + async with session.post( + f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}", + json={ + "text": payload["input"], + "model_id": request.app.state.config.TTS_MODEL, + "voice_settings": {"stability": 0.5, "similarity_boost": 0.5}, + }, + headers={ + "Accept": "audio/mpeg", + "Content-Type": "application/json", + "xi-api-key": request.app.state.config.TTS_API_KEY, + }, + ) as r: r.raise_for_status() + async with aiofiles.open(file_path, "wb") as f: await f.write(await r.read()) @@ -331,18 +338,19 @@ async def speech(request: Request, user=Depends(get_verified_user)): except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + detail = None + try: if r.status != 200: res = await r.json() if "error" in res: - error_detail = f"External: {res['error']['message']}" + detail = f"External: {res['error'].get('message', '')}" except Exception: - error_detail = f"External: {e}" + detail = f"External: {e}" raise HTTPException( status_code=getattr(r, "status", 500), - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) elif request.app.state.config.TTS_ENGINE == "azure": @@ -356,32 +364,45 @@ async def speech(request: Request, user=Depends(get_verified_user)): language = request.app.state.config.TTS_VOICE locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1]) output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT - url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1" - - headers = { - "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY, - "Content-Type": "application/ssml+xml", - "X-Microsoft-OutputFormat": output_format, - } - - data = f""" - {payload["input"]} - """ try: + data = f""" + {payload["input"]} + """ async with aiohttp.ClientSession() as session: - async with session.post(url, headers=headers, data=data) as response: - if response.status == 200: - async with aiofiles.open(file_path, "wb") as f: - await f.write(await response.read()) - return FileResponse(file_path) - else: - error_msg = f"Error synthesizing speech - {response.reason}" - log.error(error_msg) - raise HTTPException(status_code=500, detail=error_msg) + async with session.post( + f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1", + headers={ + "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY, + "Content-Type": "application/ssml+xml", + "X-Microsoft-OutputFormat": output_format, + }, + data=data, + ) as r: + r.raise_for_status() + + async with aiofiles.open(file_path, "wb") as f: + await f.write(await r.read()) + + return FileResponse(file_path) + except Exception as e: log.exception(e) - raise HTTPException(status_code=500, detail=str(e)) + detail = None + + try: + if r.status != 200: + res = await r.json() + if "error" in res: + detail = f"External: {res['error'].get('message', '')}" + except Exception: + detail = f"External: {e}" + + raise HTTPException( + status_code=getattr(r, "status", 500), + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + elif request.app.state.config.TTS_ENGINE == "transformers": payload = None try: From 3ec0a58cd7cea776b5fdd78eee8868189497d835 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 17:50:48 -0800 Subject: [PATCH 07/26] wip --- backend/open_webui/main.py | 6 +- backend/open_webui/routers/chat.py | 411 --------------------- backend/open_webui/routers/ollama.py | 3 + backend/open_webui/routers/openai.py | 510 +++++++++++++++------------ backend/open_webui/routers/webui.py | 93 ----- 5 files changed, 300 insertions(+), 723 deletions(-) delete mode 100644 backend/open_webui/routers/chat.py diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 2e1929bb3..a2d114844 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1409,9 +1409,9 @@ app.include_router(ollama.router, prefix="/ollama") app.include_router(openai.router, prefix="/openai") -app.include_router(images.router, prefix="/api/v1/images") -app.include_router(audio.router, prefix="/api/v1/audio") -app.include_router(retrieval.router, prefix="/api/v1/retrieval") +app.include_router(images.router, prefix="/api/v1/images", tags=["images"]) +app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"]) +app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"]) app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"]) diff --git a/backend/open_webui/routers/chat.py b/backend/open_webui/routers/chat.py deleted file mode 100644 index fba1ffa1b..000000000 --- a/backend/open_webui/routers/chat.py +++ /dev/null @@ -1,411 +0,0 @@ -from fastapi import APIRouter, Depends, HTTPException, Response, status -from pydantic import BaseModel - -router = APIRouter() - - -@app.post("/api/chat/completions") -async def generate_chat_completions( - request: Request, - form_data: dict, - user=Depends(get_verified_user), - bypass_filter: bool = False, -): - if BYPASS_MODEL_ACCESS_CONTROL: - bypass_filter = True - - model_list = request.state.models - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - model = models[model_id] - - # Check if user has access to the model - if not bypass_filter and user.role == "user": - if model.get("arena"): - if not has_access( - user.id, - type="read", - access_control=model.get("info", {}) - .get("meta", {}) - .get("access_control", {}), - ): - raise HTTPException( - status_code=403, - detail="Model not found", - ) - else: - model_info = Models.get_model_by_id(model_id) - if not model_info: - raise HTTPException( - status_code=404, - detail="Model not found", - ) - elif not ( - user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control - ) - ): - raise HTTPException( - status_code=403, - detail="Model not found", - ) - - if model["owned_by"] == "arena": - model_ids = model.get("info", {}).get("meta", {}).get("model_ids") - filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") - if model_ids and filter_mode == "exclude": - model_ids = [ - model["id"] - for model in await get_all_models() - if model.get("owned_by") != "arena" and model["id"] not in model_ids - ] - - selected_model_id = None - if isinstance(model_ids, list) and model_ids: - selected_model_id = random.choice(model_ids) - else: - model_ids = [ - model["id"] - for model in await get_all_models() - if model.get("owned_by") != "arena" - ] - selected_model_id = random.choice(model_ids) - - form_data["model"] = selected_model_id - - if form_data.get("stream") == True: - - async def stream_wrapper(stream): - yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" - async for chunk in stream: - yield chunk - - response = await generate_chat_completions( - form_data, user, bypass_filter=True - ) - return StreamingResponse( - stream_wrapper(response.body_iterator), media_type="text/event-stream" - ) - else: - return { - **( - await generate_chat_completions(form_data, user, bypass_filter=True) - ), - "selected_model_id": selected_model_id, - } - - if model.get("pipe"): - # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter - return await generate_function_chat_completion( - form_data, user=user, models=models - ) - if model["owned_by"] == "ollama": - # Using /ollama/api/chat endpoint - form_data = convert_payload_openai_to_ollama(form_data) - form_data = GenerateChatCompletionForm(**form_data) - response = await generate_ollama_chat_completion( - form_data=form_data, user=user, bypass_filter=bypass_filter - ) - if form_data.stream: - response.headers["content-type"] = "text/event-stream" - return StreamingResponse( - convert_streaming_response_ollama_to_openai(response), - headers=dict(response.headers), - ) - else: - return convert_response_ollama_to_openai(response) - else: - return await generate_openai_chat_completion( - form_data, user=user, bypass_filter=bypass_filter - ) - - -@app.post("/api/chat/completed") -async def chat_completed(form_data: dict, user=Depends(get_verified_user)): - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - data = form_data - model_id = data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - model = models[model_id] - sorted_filters = get_sorted_filters(model_id, models) - if "pipeline" in model: - sorted_filters = [model] + sorted_filters - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/outlet", - headers=headers, - json={ - "user": { - "id": user.id, - "name": user.name, - "email": user.email, - "role": user.role, - }, - "body": data, - }, - ) - - r.raise_for_status() - data = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - try: - res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) - except Exception: - pass - - else: - pass - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel to include vavles - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - # Sort filter_ids by priority, using the get_priority function - filter_ids.sort(key=get_priority) - - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if not hasattr(function_module, "outlet"): - continue - try: - outlet = function_module.outlet - - # Get the signature of the function - sig = inspect.signature(outlet) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(outlet): - data = await outlet(**params) - else: - data = outlet(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data - - -@app.post("/api/chat/actions/{action_id}") -async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)): - if "." in action_id: - action_id, sub_action_id = action_id.split(".") - else: - sub_action_id = None - - action = Functions.get_function_by_id(action_id) - if not action: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Action not found", - ) - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - data = form_data - model_id = data["model"] - - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - model = models[model_id] - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - if action_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[action_id] - else: - function_module, _, _ = load_function_module_by_id(action_id) - webui_app.state.FUNCTIONS[action_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(action_id) - function_module.valves = function_module.Valves(**(valves if valves else {})) - - if hasattr(function_module, "action"): - try: - action = function_module.action - - # Get the signature of the function - sig = inspect.signature(action) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": sub_action_id if sub_action_id is not None else action_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - action_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(action): - data = await action(**params) - else: - data = action(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 082d14ec3..b217b8f45 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -317,6 +317,9 @@ async def get_all_models(request: Request): else: models = {"models": []} + request.app.state.OLLAMA_MODELS = { + model["model"]: model for model in models["models"] + } return models diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 1e9ca4af7..34c5683a8 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -10,15 +10,15 @@ from aiocache import cached import requests +from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse, StreamingResponse +from pydantic import BaseModel +from starlette.background import BackgroundTask + from open_webui.models.models import Models from open_webui.config import ( CACHE_DIR, - CORS_ALLOW_ORIGIN, - ENABLE_OPENAI_API, - OPENAI_API_BASE_URLS, - OPENAI_API_KEYS, - OPENAI_API_CONFIGS, - AppConfig, ) from open_webui.env import ( AIOHTTP_CLIENT_TIMEOUT, @@ -29,11 +29,7 @@ from open_webui.env import ( from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENV, SRC_LOG_LEVELS -from fastapi import Depends, FastAPI, HTTPException, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse, StreamingResponse -from pydantic import BaseModel -from starlette.background import BackgroundTask + from open_webui.utils.payload import ( apply_model_params_to_body_openai, @@ -48,13 +44,69 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OPENAI"]) -@app.get("/config") -async def get_config(user=Depends(get_admin_user)): +########################################## +# +# Utility functions +# +########################################## + + +async def send_get_request(url, key=None): + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + try: + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: + async with session.get( + url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + ) as response: + return await response.json() + except Exception as e: + # Handle connection error here + log.error(f"Connection error: {e}") + return None + + +async def cleanup_response( + response: Optional[aiohttp.ClientResponse], + session: Optional[aiohttp.ClientSession], +): + if response: + response.close() + if session: + await session.close() + + +def openai_o1_handler(payload): + """ + Handle O1 specific parameters + """ + if "max_tokens" in payload: + # Remove "max_tokens" from the payload + payload["max_completion_tokens"] = payload["max_tokens"] + del payload["max_tokens"] + + # Fix: O1 does not support the "system" parameter, Modify "system" to "user" + if payload["messages"][0]["role"] == "system": + payload["messages"][0]["role"] = "user" + + return payload + + +########################################## +# +# API routes +# +########################################## + +router = APIRouter() + + +@router.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): return { - "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, - "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, - "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, - "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, + "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS, } @@ -65,49 +117,56 @@ class OpenAIConfigForm(BaseModel): OPENAI_API_CONFIGS: dict -@app.post("/config/update") -async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)): - app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API - app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS - app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS +@router.post("/config/update") +async def update_config( + request: Request, form_data: OpenAIConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API + request.app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS + request.app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS # Check if API KEYS length is same than API URLS length - if len(app.state.config.OPENAI_API_KEYS) != len( - app.state.config.OPENAI_API_BASE_URLS + if len(request.app.state.config.OPENAI_API_KEYS) != len( + request.app.state.config.OPENAI_API_BASE_URLS ): - if len(app.state.config.OPENAI_API_KEYS) > len( - app.state.config.OPENAI_API_BASE_URLS + if len(request.app.state.config.OPENAI_API_KEYS) > len( + request.app.state.config.OPENAI_API_BASE_URLS ): - app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[ - : len(app.state.config.OPENAI_API_BASE_URLS) - ] + request.app.state.config.OPENAI_API_KEYS = ( + request.app.state.config.OPENAI_API_KEYS[ + : len(request.app.state.config.OPENAI_API_BASE_URLS) + ] + ) else: - app.state.config.OPENAI_API_KEYS += [""] * ( - len(app.state.config.OPENAI_API_BASE_URLS) - - len(app.state.config.OPENAI_API_KEYS) + request.app.state.config.OPENAI_API_KEYS += [""] * ( + len(request.app.state.config.OPENAI_API_BASE_URLS) + - len(request.app.state.config.OPENAI_API_KEYS) ) - app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS + request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS # Remove any extra configs - config_urls = app.state.config.OPENAI_API_CONFIGS.keys() - for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS): + config_urls = request.app.state.config.OPENAI_API_CONFIGS.keys() + for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS): if url not in config_urls: - app.state.config.OPENAI_API_CONFIGS.pop(url, None) + request.app.state.config.OPENAI_API_CONFIGS.pop(url, None) return { - "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, - "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, - "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, - "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, + "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS, } -@app.post("/audio/speech") +@router.post("/audio/speech") async def speech(request: Request, user=Depends(get_verified_user)): idx = None try: - idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") + idx = request.app.state.config.OPENAI_API_BASE_URLS.index( + "https://api.openai.com/v1" + ) + body = await request.body() name = hashlib.sha256(body).hexdigest() @@ -120,23 +179,35 @@ async def speech(request: Request, user=Depends(get_verified_user)): if file_path.is_file(): return FileResponse(file_path) - headers = {} - headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}" - headers["Content-Type"] = "application/json" - if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]: - headers["HTTP-Referer"] = "https://openwebui.com/" - headers["X-Title"] = "Open WebUI" - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + r = None try: r = requests.post( - url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech", + url=f"{url}/audio/speech", data=body, - headers=headers, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}", + **( + { + "HTTP-Referer": "https://openwebui.com/", + "X-Title": "Open WebUI", + } + if "openrouter.ai" in url + else {} + ), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, stream=True, ) @@ -155,46 +226,25 @@ async def speech(request: Request, user=Depends(get_verified_user)): except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"External: {res['error']}" + detail = f"External: {res['error']}" except Exception: - error_detail = f"External: {e}" + detail = f"External: {e}" raise HTTPException( - status_code=r.status_code if r else 500, detail=error_detail + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", ) except ValueError: raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) -async def aiohttp_get(url, key=None): - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) - try: - headers = {"Authorization": f"Bearer {key}"} if key else {} - async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get(url, headers=headers) as response: - return await response.json() - except Exception as e: - # Handle connection error here - log.error(f"Connection error: {e}") - return None - - -async def cleanup_response( - response: Optional[aiohttp.ClientResponse], - session: Optional[aiohttp.ClientSession], -): - if response: - response.close() - if session: - await session.close() - - def merge_models_lists(model_lists): log.debug(f"merge_models_lists {model_lists}") merged_list = [] @@ -212,7 +262,7 @@ def merge_models_lists(model_lists): } for model in models if "api.openai.com" - not in app.state.config.OPENAI_API_BASE_URLS[idx] + not in request.app.state.config.OPENAI_API_BASE_URLS[idx] or not any( name in model["id"] for name in [ @@ -230,40 +280,43 @@ def merge_models_lists(model_lists): return merged_list -async def get_all_models_responses() -> list: - if not app.state.config.ENABLE_OPENAI_API: +async def get_all_models_responses(request: Request) -> list: + if not request.app.state.config.ENABLE_OPENAI_API: return [] # Check if API KEYS length is same than API URLS length - num_urls = len(app.state.config.OPENAI_API_BASE_URLS) - num_keys = len(app.state.config.OPENAI_API_KEYS) + num_urls = len(request.app.state.config.OPENAI_API_BASE_URLS) + num_keys = len(request.app.state.config.OPENAI_API_KEYS) if num_keys != num_urls: # if there are more keys than urls, remove the extra keys if num_keys > num_urls: - new_keys = app.state.config.OPENAI_API_KEYS[:num_urls] - app.state.config.OPENAI_API_KEYS = new_keys + new_keys = request.app.state.config.OPENAI_API_KEYS[:num_urls] + request.app.state.config.OPENAI_API_KEYS = new_keys # if there are more urls than keys, add empty keys else: - app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys) + request.app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys) - tasks = [] - for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS): - if url not in app.state.config.OPENAI_API_CONFIGS: - tasks.append( - aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) + request_tasks = [] + for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS): + if url not in request.app.state.config.OPENAI_API_CONFIGS: + request_tasks.append( + send_get_request( + f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx] + ) ) else: - api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {}) enable = api_config.get("enable", True) model_ids = api_config.get("model_ids", []) if enable: if len(model_ids) == 0: - tasks.append( - aiohttp_get( - f"{url}/models", app.state.config.OPENAI_API_KEYS[idx] + request_tasks.append( + send_get_request( + f"{url}/models", + request.app.state.config.OPENAI_API_KEYS[idx], ) ) else: @@ -281,16 +334,18 @@ async def get_all_models_responses() -> list: ], } - tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list))) + request_tasks.append( + asyncio.ensure_future(asyncio.sleep(0, model_list)) + ) else: - tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) + request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) - responses = await asyncio.gather(*tasks) + responses = await asyncio.gather(*request_tasks) for idx, response in enumerate(responses): if response: - url = app.state.config.OPENAI_API_BASE_URLS[idx] - api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {}) + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {}) prefix_id = api_config.get("prefix_id", None) @@ -301,15 +356,27 @@ async def get_all_models_responses() -> list: model["id"] = f"{prefix_id}.{model['id']}" log.debug(f"get_all_models:responses() {responses}") - return responses +async def get_filtered_models(models, user): + # Filter models based on user access control + filtered_models = [] + for model in models.get("data", []): + model_info = Models.get_model_by_id(model["id"]) + if model_info: + if user.id == model_info.user_id or has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + return filtered_models + + @cached(ttl=3) -async def get_all_models() -> dict[str, list]: +async def get_all_models(request: Request) -> dict[str, list]: log.info("get_all_models()") - if not app.state.config.ENABLE_OPENAI_API: + if not request.app.state.config.ENABLE_OPENAI_API: return {"data": []} responses = await get_all_models_responses() @@ -324,12 +391,15 @@ async def get_all_models() -> dict[str, list]: models = {"data": merge_models_lists(map(extract_data, responses))} log.debug(f"models: {models}") + request.app.state.OPENAI_MODELS = {model["id"]: model for model in models["data"]} return models -@app.get("/models") -@app.get("/models/{url_idx}") -async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)): +@router.get("/models") +@router.get("/models/{url_idx}") +async def get_models( + request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user) +): models = { "data": [], } @@ -337,25 +407,33 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us if url_idx is None: models = await get_all_models() else: - url = app.state.config.OPENAI_API_BASE_URLS[url_idx] - key = app.state.config.OPENAI_API_KEYS[url_idx] - - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role + url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx] + key = request.app.state.config.OPENAI_API_KEYS[url_idx] r = None - - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) - async with aiohttp.ClientSession(timeout=timeout) as session: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST + ) + ) as session: try: - async with session.get(f"{url}/models", headers=headers) as r: + async with session.get( + f"{url}/models", + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, + ) as r: if r.status != 200: # Extract response error details if available error_detail = f"HTTP Error: {r.status}" @@ -389,27 +467,16 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us except aiohttp.ClientError as e: # ClientError covers all aiohttp requests issues log.exception(f"Client error: {str(e)}") - # Handle aiohttp-specific connection issues, timeout etc. raise HTTPException( status_code=500, detail="Open WebUI: Server Connection Error" ) except Exception as e: log.exception(f"Unexpected error: {e}") - # Generic error handler in case parsing JSON or other steps fail error_detail = f"Unexpected error: {str(e)}" raise HTTPException(status_code=500, detail=error_detail) if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: - # Filter models based on user access control - filtered_models = [] - for model in models.get("data", []): - model_info = Models.get_model_by_id(model["id"]) - if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control - ): - filtered_models.append(model) - models["data"] = filtered_models + models["data"] = get_filtered_models(models, user) return models @@ -419,21 +486,24 @@ class ConnectionVerificationForm(BaseModel): key: str -@app.post("/verify") +@router.post("/verify") async def verify_connection( form_data: ConnectionVerificationForm, user=Depends(get_admin_user) ): url = form_data.url key = form_data.key - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) - async with aiohttp.ClientSession(timeout=timeout) as session: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + ) as session: try: - async with session.get(f"{url}/models", headers=headers) as r: + async with session.get( + f"{url}/models", + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + }, + ) as r: if r.status != 200: # Extract response error details if available error_detail = f"HTTP Error: {r.status}" @@ -448,26 +518,24 @@ async def verify_connection( except aiohttp.ClientError as e: # ClientError covers all aiohttp requests issues log.exception(f"Client error: {str(e)}") - # Handle aiohttp-specific connection issues, timeout etc. raise HTTPException( status_code=500, detail="Open WebUI: Server Connection Error" ) except Exception as e: log.exception(f"Unexpected error: {e}") - # Generic error handler in case parsing JSON or other steps fail error_detail = f"Unexpected error: {str(e)}" raise HTTPException(status_code=500, detail=error_detail) -@app.post("/chat/completions") +@router.post("/chat/completions") async def generate_chat_completion( + request: Request, form_data: dict, user=Depends(get_verified_user), bypass_filter: Optional[bool] = False, ): idx = 0 payload = {**form_data} - if "metadata" in payload: del payload["metadata"] @@ -502,15 +570,7 @@ async def generate_chat_completion( detail="Model not found", ) - # Attemp to get urlIdx from the model - models = await get_all_models() - - # Find the model from the list - model = next( - (model for model in models["data"] if model["id"] == payload.get("model")), - None, - ) - + model = request.app.state.OPENAI_MODELS.get(model_id) if model: idx = model["urlIdx"] else: @@ -520,11 +580,11 @@ async def generate_chat_completion( ) # Get the API config for the model - api_config = app.state.config.OPENAI_API_CONFIGS.get( - app.state.config.OPENAI_API_BASE_URLS[idx], {} + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + request.app.state.config.OPENAI_API_BASE_URLS[idx], {} ) - prefix_id = api_config.get("prefix_id", None) + prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") @@ -537,43 +597,26 @@ async def generate_chat_completion( "role": user.role, } - url = app.state.config.OPENAI_API_BASE_URLS[idx] - key = app.state.config.OPENAI_API_KEYS[idx] + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens" is_o1 = payload["model"].lower().startswith("o1-") - # Change max_completion_tokens to max_tokens (Backward compatible) - if "api.openai.com" not in url and not is_o1: - if "max_completion_tokens" in payload: - # Remove "max_completion_tokens" from the payload - payload["max_tokens"] = payload["max_completion_tokens"] - del payload["max_completion_tokens"] - else: - if is_o1 and "max_tokens" in payload: + if is_o1: + payload = openai_o1_handler(payload) + elif "api.openai.com" not in url: + # Remove "max_tokens" from the payload for backward compatibility + if "max_tokens" in payload: payload["max_completion_tokens"] = payload["max_tokens"] del payload["max_tokens"] - if "max_tokens" in payload and "max_completion_tokens" in payload: - del payload["max_tokens"] - # Fix: O1 does not support the "system" parameter, Modify "system" to "user" - if is_o1 and payload["messages"][0]["role"] == "system": - payload["messages"][0]["role"] = "user" + # TODO: check if below is needed + # if "max_tokens" in payload and "max_completion_tokens" in payload: + # del payload["max_tokens"] # Convert the modified body back to JSON payload = json.dumps(payload) - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]: - headers["HTTP-Referer"] = "https://openwebui.com/" - headers["X-Title"] = "Open WebUI" - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role - r = None session = None streaming = False @@ -583,11 +626,33 @@ async def generate_chat_completion( session = aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) + r = await session.request( method="POST", url=f"{url}/chat/completions", data=payload, - headers=headers, + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "HTTP-Referer": "https://openwebui.com/", + "X-Title": "Open WebUI", + } + if "openrouter.ai" in url + else {} + ), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, ) # Check if response is SSE @@ -612,14 +677,18 @@ async def generate_chat_completion( return response except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if isinstance(response, dict): if "error" in response: - error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}" + detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}" elif isinstance(response, str): - error_detail = response + detail = response - raise HTTPException(status_code=r.status if r else 500, detail=error_detail) + raise HTTPException( + status_code=r.status if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) finally: if not streaming and session: if r: @@ -627,25 +696,17 @@ async def generate_chat_completion( await session.close() -@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_verified_user)): - idx = 0 + """ + Deprecated: proxy all requests to OpenAI API + """ body = await request.body() - url = app.state.config.OPENAI_API_BASE_URLS[idx] - key = app.state.config.OPENAI_API_KEYS[idx] - - target_url = f"{url}/{path}" - - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role + idx = 0 + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] r = None session = None @@ -655,11 +716,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): session = aiohttp.ClientSession(trust_env=True) r = await session.request( method=request.method, - url=target_url, + url=f"{url}/{path}", data=body, - headers=headers, + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, ) - r.raise_for_status() # Check if response is SSE @@ -676,18 +749,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): else: response_data = await r.json() return response_data + except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = await r.json() print(res) if "error" in res: - error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" + detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" except Exception: - error_detail = f"External: {e}" - raise HTTPException(status_code=r.status if r else 500, detail=error_detail) + detail = f"External: {e}" + raise HTTPException( + status_code=r.status if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) finally: if not streaming and session: if r: diff --git a/backend/open_webui/routers/webui.py b/backend/open_webui/routers/webui.py index 1ac4db152..d3942db97 100644 --- a/backend/open_webui/routers/webui.py +++ b/backend/open_webui/routers/webui.py @@ -89,103 +89,10 @@ from open_webui.utils.payload import ( from open_webui.utils.tools import get_tools -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) -app.state.config = AppConfig() - -app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP -app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM -app.state.config.ENABLE_API_KEY = ENABLE_API_KEY - -app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN -app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER -app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER - - -app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS -app.state.config.ADMIN_EMAIL = ADMIN_EMAIL - - -app.state.config.DEFAULT_MODELS = DEFAULT_MODELS -app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS -app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE - - -app.state.config.USER_PERMISSIONS = USER_PERMISSIONS -app.state.config.WEBHOOK_URL = WEBHOOK_URL -app.state.config.BANNERS = WEBUI_BANNERS -app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST - -app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING -app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING - -app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS -app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS - -app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM -app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM -app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM - -app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT -app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM -app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES -app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES - -app.state.config.ENABLE_LDAP = ENABLE_LDAP -app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL -app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST -app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT -app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME -app.state.config.LDAP_APP_DN = LDAP_APP_DN -app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD -app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE -app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS -app.state.config.LDAP_USE_TLS = LDAP_USE_TLS -app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE -app.state.config.LDAP_CIPHERS = LDAP_CIPHERS - -app.state.TOOLS = {} -app.state.FUNCTIONS = {} - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -app.include_router(configs.router, prefix="/configs", tags=["configs"]) - -app.include_router(auths.router, prefix="/auths", tags=["auths"]) -app.include_router(users.router, prefix="/users", tags=["users"]) - -app.include_router(chats.router, prefix="/chats", tags=["chats"]) - -app.include_router(models.router, prefix="/models", tags=["models"]) -app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"]) -app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) -app.include_router(tools.router, prefix="/tools", tags=["tools"]) - -app.include_router(memories.router, prefix="/memories", tags=["memories"]) -app.include_router(folders.router, prefix="/folders", tags=["folders"]) - -app.include_router(groups.router, prefix="/groups", tags=["groups"]) -app.include_router(files.router, prefix="/files", tags=["files"]) -app.include_router(functions.router, prefix="/functions", tags=["functions"]) -app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"]) - - -app.include_router(utils.router, prefix="/utils", tags=["utils"]) - @app.get("/") async def get_status(): From 867c4bc0d0a86270c29eeea0c6c6c524bc3e217d Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 18:05:42 -0800 Subject: [PATCH 08/26] wip: retrieval --- backend/open_webui/main.py | 419 +++++++- backend/open_webui/retrieval/utils.py | 2 +- .../open_webui/retrieval/vector/connector.py | 10 +- .../open_webui/retrieval/vector/dbs/chroma.py | 2 +- .../open_webui/retrieval/vector/dbs/milvus.py | 2 +- .../retrieval/vector/dbs/opensearch.py | 2 +- .../retrieval/vector/dbs/pgvector.py | 2 +- .../open_webui/retrieval/vector/dbs/qdrant.py | 2 +- backend/open_webui/retrieval/web/bing.py | 2 +- backend/open_webui/retrieval/web/brave.py | 2 +- .../open_webui/retrieval/web/duckduckgo.py | 2 +- .../open_webui/retrieval/web/google_pse.py | 2 +- .../open_webui/retrieval/web/jina_search.py | 2 +- backend/open_webui/retrieval/web/kagi.py | 10 +- backend/open_webui/retrieval/web/mojeek.py | 2 +- backend/open_webui/retrieval/web/searchapi.py | 2 +- backend/open_webui/retrieval/web/searxng.py | 2 +- backend/open_webui/retrieval/web/serper.py | 2 +- backend/open_webui/retrieval/web/serply.py | 2 +- backend/open_webui/retrieval/web/serpstack.py | 2 +- backend/open_webui/retrieval/web/tavily.py | 2 +- backend/open_webui/routers/knowledge.py | 2 +- backend/open_webui/routers/memories.py | 2 +- backend/open_webui/routers/retrieval.py | 964 ++++++++---------- 24 files changed, 897 insertions(+), 546 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index a2d114844..8184b467b 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -516,9 +516,12 @@ app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_K app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS +app.state.EMBEDDING_FUNCTION = None +app.state.sentence_transformer_ef = None +app.state.sentence_transformer_rf = None app.state.YOUTUBE_LOADER_TRANSLATION = None -app.state.EMBEDDING_FUNCTION = None + ######################################## # @@ -1653,6 +1656,420 @@ async def get_base_models(user=Depends(get_admin_user)): return {"data": models} +################################## +# +# Chat Endpoints +# +################################## + + +@app.post("/api/chat/completions") +async def generate_chat_completions( + request: Request, + form_data: dict, + user=Depends(get_verified_user), + bypass_filter: bool = False, +): + if BYPASS_MODEL_ACCESS_CONTROL: + bypass_filter = True + + model_list = request.state.models + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = models[model_id] + + # Check if user has access to the model + if not bypass_filter and user.role == "user": + if model.get("arena"): + if not has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + else: + model_info = Models.get_model_by_id(model_id) + if not model_info: + raise HTTPException( + status_code=404, + detail="Model not found", + ) + elif not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + + if model["owned_by"] == "arena": + model_ids = model.get("info", {}).get("meta", {}).get("model_ids") + filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") + if model_ids and filter_mode == "exclude": + model_ids = [ + model["id"] + for model in await get_all_models() + if model.get("owned_by") != "arena" and model["id"] not in model_ids + ] + + selected_model_id = None + if isinstance(model_ids, list) and model_ids: + selected_model_id = random.choice(model_ids) + else: + model_ids = [ + model["id"] + for model in await get_all_models() + if model.get("owned_by") != "arena" + ] + selected_model_id = random.choice(model_ids) + + form_data["model"] = selected_model_id + + if form_data.get("stream") == True: + + async def stream_wrapper(stream): + yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" + async for chunk in stream: + yield chunk + + response = await generate_chat_completions( + form_data, user, bypass_filter=True + ) + return StreamingResponse( + stream_wrapper(response.body_iterator), media_type="text/event-stream" + ) + else: + return { + **( + await generate_chat_completions(form_data, user, bypass_filter=True) + ), + "selected_model_id": selected_model_id, + } + + if model.get("pipe"): + # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter + return await generate_function_chat_completion( + form_data, user=user, models=models + ) + if model["owned_by"] == "ollama": + # Using /ollama/api/chat endpoint + form_data = convert_payload_openai_to_ollama(form_data) + form_data = GenerateChatCompletionForm(**form_data) + response = await generate_ollama_chat_completion( + form_data=form_data, user=user, bypass_filter=bypass_filter + ) + if form_data.stream: + response.headers["content-type"] = "text/event-stream" + return StreamingResponse( + convert_streaming_response_ollama_to_openai(response), + headers=dict(response.headers), + ) + else: + return convert_response_ollama_to_openai(response) + else: + return await generate_openai_chat_completion( + form_data, user=user, bypass_filter=bypass_filter + ) + + +@app.post("/api/chat/completed") +async def chat_completed(form_data: dict, user=Depends(get_verified_user)): + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + data = form_data + model_id = data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = models[model_id] + sorted_filters = get_sorted_filters(model_id, models) + if "pipeline" in model: + sorted_filters = [model] + sorted_filters + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + if key != "": + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/outlet", + headers=headers, + json={ + "user": { + "id": user.id, + "name": user.name, + "email": user.email, + "role": user.role, + }, + "body": data, + }, + ) + + r.raise_for_status() + data = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + try: + res = r.json() + if "detail" in res: + return JSONResponse( + status_code=r.status_code, + content=res, + ) + except Exception: + pass + + else: + pass + + __event_emitter__ = get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + __event_call__ = get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel to include vavles + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [function.id for function in Functions.get_global_filter_functions()] + if "info" in model and "meta" in model["info"]: + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) + + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + filter_ids = [ + filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids + ] + + # Sort filter_ids by priority, using the get_priority function + filter_ids.sort(key=get_priority) + + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if not filter: + continue + + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + webui_app.state.FUNCTIONS[filter_id] = function_module + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) + + if not hasattr(function_module, "outlet"): + continue + try: + outlet = function_module.outlet + + # Get the signature of the function + sig = inspect.signature(outlet) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": filter_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id + ) + ) + except Exception as e: + print(e) + + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(outlet): + data = await outlet(**params) + else: + data = outlet(**params) + + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + return data + + +@app.post("/api/chat/actions/{action_id}") +async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)): + if "." in action_id: + action_id, sub_action_id = action_id.split(".") + else: + sub_action_id = None + + action = Functions.get_function_by_id(action_id) + if not action: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Action not found", + ) + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + data = form_data + model_id = data["model"] + + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + model = models[model_id] + + __event_emitter__ = get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + __event_call__ = get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + if action_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[action_id] + else: + function_module, _, _ = load_function_module_by_id(action_id) + webui_app.state.FUNCTIONS[action_id] = function_module + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(action_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + + if hasattr(function_module, "action"): + try: + action = function_module.action + + # Get the signature of the function + sig = inspect.signature(action) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": sub_action_id if sub_action_id is not None else action_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + action_id, user.id + ) + ) + except Exception as e: + print(e) + + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(action): + data = await action(**params) + else: + data = action(**params) + + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + return data + + ################################## # # Config Endpoints diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index bf939ecf1..9444ade95 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -11,7 +11,7 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/retrieval/vector/connector.py b/backend/open_webui/retrieval/vector/connector.py index 528835b56..bf97bc7b1 100644 --- a/backend/open_webui/retrieval/vector/connector.py +++ b/backend/open_webui/retrieval/vector/connector.py @@ -1,22 +1,22 @@ from open_webui.config import VECTOR_DB if VECTOR_DB == "milvus": - from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient + from open_webui.retrieval.vector.dbs.milvus import MilvusClient VECTOR_DB_CLIENT = MilvusClient() elif VECTOR_DB == "qdrant": - from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient + from open_webui.retrieval.vector.dbs.qdrant import QdrantClient VECTOR_DB_CLIENT = QdrantClient() elif VECTOR_DB == "opensearch": - from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient + from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient VECTOR_DB_CLIENT = OpenSearchClient() elif VECTOR_DB == "pgvector": - from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient + from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient VECTOR_DB_CLIENT = PgvectorClient() else: - from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient + from open_webui.retrieval.vector.dbs.chroma import ChromaClient VECTOR_DB_CLIENT = ChromaClient() diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py index b2fcdd16a..00d73a889 100644 --- a/backend/open_webui/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches from typing import Optional -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( CHROMA_DATA_PATH, CHROMA_HTTP_HOST, diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py index 5351f860e..31d890664 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -4,7 +4,7 @@ import json from typing import Optional -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( MILVUS_URI, ) diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py index 6234b2837..b3d8b5eb8 100644 --- a/backend/open_webui/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -1,7 +1,7 @@ from opensearchpy import OpenSearch from typing import Optional -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( OPENSEARCH_URI, OPENSEARCH_SSL, diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index b8317957e..cb8c545e9 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -18,7 +18,7 @@ from sqlalchemy.dialects.postgresql import JSONB, array from pgvector.sqlalchemy import Vector from sqlalchemy.ext.mutable import MutableDict -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import PGVECTOR_DB_URL VECTOR_LENGTH = 1536 diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py index 60c1c3d4d..f077ae45a 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -4,7 +4,7 @@ from qdrant_client import QdrantClient as Qclient from qdrant_client.http.models import PointStruct from qdrant_client.models import models -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import QDRANT_URI, QDRANT_API_KEY NO_LIMIT = 999999999 diff --git a/backend/open_webui/retrieval/web/bing.py b/backend/open_webui/retrieval/web/bing.py index b5f889c54..09beb3460 100644 --- a/backend/open_webui/retrieval/web/bing.py +++ b/backend/open_webui/retrieval/web/bing.py @@ -3,7 +3,7 @@ import os from pprint import pprint from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS import argparse diff --git a/backend/open_webui/retrieval/web/brave.py b/backend/open_webui/retrieval/web/brave.py index f988b3b08..3075db990 100644 --- a/backend/open_webui/retrieval/web/brave.py +++ b/backend/open_webui/retrieval/web/brave.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/duckduckgo.py b/backend/open_webui/retrieval/web/duckduckgo.py index 11e512296..7c0c3f1c2 100644 --- a/backend/open_webui/retrieval/web/duckduckgo.py +++ b/backend/open_webui/retrieval/web/duckduckgo.py @@ -1,7 +1,7 @@ import logging from typing import Optional -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from duckduckgo_search import DDGS from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/retrieval/web/google_pse.py b/backend/open_webui/retrieval/web/google_pse.py index 61b919583..2c51dd3c9 100644 --- a/backend/open_webui/retrieval/web/google_pse.py +++ b/backend/open_webui/retrieval/web/google_pse.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/jina_search.py b/backend/open_webui/retrieval/web/jina_search.py index f5e2febbe..3de6c1807 100644 --- a/backend/open_webui/retrieval/web/jina_search.py +++ b/backend/open_webui/retrieval/web/jina_search.py @@ -1,7 +1,7 @@ import logging import requests -from open_webui.apps.retrieval.web.main import SearchResult +from open_webui.retrieval.web.main import SearchResult from open_webui.env import SRC_LOG_LEVELS from yarl import URL diff --git a/backend/open_webui/retrieval/web/kagi.py b/backend/open_webui/retrieval/web/kagi.py index c8c2699ed..0b69da8bc 100644 --- a/backend/open_webui/retrieval/web/kagi.py +++ b/backend/open_webui/retrieval/web/kagi.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) @@ -31,17 +31,15 @@ def search_kagi( response.raise_for_status() json_response = response.json() search_results = json_response.get("data", []) - + results = [ SearchResult( - link=result["url"], - title=result["title"], - snippet=result.get("snippet") + link=result["url"], title=result["title"], snippet=result.get("snippet") ) for result in search_results if result["t"] == 0 ] - + print(results) if filter_list: diff --git a/backend/open_webui/retrieval/web/mojeek.py b/backend/open_webui/retrieval/web/mojeek.py index f257c92aa..d298b0ee5 100644 --- a/backend/open_webui/retrieval/web/mojeek.py +++ b/backend/open_webui/retrieval/web/mojeek.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/searchapi.py b/backend/open_webui/retrieval/web/searchapi.py index 412dc6b69..38bc0b574 100644 --- a/backend/open_webui/retrieval/web/searchapi.py +++ b/backend/open_webui/retrieval/web/searchapi.py @@ -3,7 +3,7 @@ from typing import Optional from urllib.parse import urlencode import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/searxng.py b/backend/open_webui/retrieval/web/searxng.py index cb1eaf91d..15e3c098a 100644 --- a/backend/open_webui/retrieval/web/searxng.py +++ b/backend/open_webui/retrieval/web/searxng.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/serper.py b/backend/open_webui/retrieval/web/serper.py index 436fa167e..685e34375 100644 --- a/backend/open_webui/retrieval/web/serper.py +++ b/backend/open_webui/retrieval/web/serper.py @@ -3,7 +3,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/serply.py b/backend/open_webui/retrieval/web/serply.py index 1c2521c47..a9b473eb0 100644 --- a/backend/open_webui/retrieval/web/serply.py +++ b/backend/open_webui/retrieval/web/serply.py @@ -3,7 +3,7 @@ from typing import Optional from urllib.parse import urlencode import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/serpstack.py b/backend/open_webui/retrieval/web/serpstack.py index b655934de..d4dbda57c 100644 --- a/backend/open_webui/retrieval/web/serpstack.py +++ b/backend/open_webui/retrieval/web/serpstack.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/tavily.py b/backend/open_webui/retrieval/web/tavily.py index 03b0be75a..cc468725d 100644 --- a/backend/open_webui/retrieval/web/tavily.py +++ b/backend/open_webui/retrieval/web/tavily.py @@ -1,7 +1,7 @@ import logging import requests -from open_webui.apps.retrieval.web.main import SearchResult +from open_webui.retrieval.web.main import SearchResult from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 1617f452e..85a4d30fd 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -11,7 +11,7 @@ from open_webui.models.knowledge import ( KnowledgeUserResponse, ) from open_webui.models.files import Files, FileModel -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from backend.open_webui.routers.retrieval import process_file, ProcessFileForm diff --git a/backend/open_webui/routers/memories.py b/backend/open_webui/routers/memories.py index 7973038c4..e72cf1445 100644 --- a/backend/open_webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -4,7 +4,7 @@ import logging from typing import Optional from open_webui.models.memories import Memories, MemoryModel -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.auth import get_verified_user from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 517e2894d..7e0dc6018 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -1,5 +1,3 @@ -# TODO: Merge this with the webui_app and make it a single app - import json import logging import mimetypes @@ -11,39 +9,55 @@ from datetime import datetime from pathlib import Path from typing import Iterator, Optional, Sequence, Union -from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + UploadFile, + Request, + status, + APIRouter, +) from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import tiktoken -from open_webui.storage.provider import Storage +from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter +from langchain_core.documents import Document + +from open_webui.models.files import Files from open_webui.models.knowledge import Knowledges -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.storage.provider import Storage + + +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT # Document loaders -from open_webui.apps.retrieval.loaders.main import Loader -from open_webui.apps.retrieval.loaders.youtube import YoutubeLoader +from open_webui.retrieval.loaders.main import Loader +from open_webui.retrieval.loaders.youtube import YoutubeLoader # Web search engines -from open_webui.apps.retrieval.web.main import SearchResult -from open_webui.apps.retrieval.web.utils import get_web_loader -from open_webui.apps.retrieval.web.brave import search_brave -from open_webui.apps.retrieval.web.kagi import search_kagi -from open_webui.apps.retrieval.web.mojeek import search_mojeek -from open_webui.apps.retrieval.web.duckduckgo import search_duckduckgo -from open_webui.apps.retrieval.web.google_pse import search_google_pse -from open_webui.apps.retrieval.web.jina_search import search_jina -from open_webui.apps.retrieval.web.searchapi import search_searchapi -from open_webui.apps.retrieval.web.searxng import search_searxng -from open_webui.apps.retrieval.web.serper import search_serper -from open_webui.apps.retrieval.web.serply import search_serply -from open_webui.apps.retrieval.web.serpstack import search_serpstack -from open_webui.apps.retrieval.web.tavily import search_tavily -from open_webui.apps.retrieval.web.bing import search_bing +from open_webui.retrieval.web.main import SearchResult +from open_webui.retrieval.web.utils import get_web_loader +from open_webui.retrieval.web.brave import search_brave +from open_webui.retrieval.web.kagi import search_kagi +from open_webui.retrieval.web.mojeek import search_mojeek +from open_webui.retrieval.web.duckduckgo import search_duckduckgo +from open_webui.retrieval.web.google_pse import search_google_pse +from open_webui.retrieval.web.jina_search import search_jina +from open_webui.retrieval.web.searchapi import search_searchapi +from open_webui.retrieval.web.searxng import search_searxng +from open_webui.retrieval.web.serper import search_serper +from open_webui.retrieval.web.serply import search_serply +from open_webui.retrieval.web.serpstack import search_serpstack +from open_webui.retrieval.web.tavily import search_tavily +from open_webui.retrieval.web.bing import search_bing -from backend.open_webui.retrieval.utils import ( +from open_webui.retrieval.utils import ( get_embedding_function, get_model_path, query_collection, @@ -51,246 +65,132 @@ from backend.open_webui.retrieval.utils import ( query_doc, query_doc_with_hybrid_search, ) +from open_webui.utils.misc import ( + calculate_sha256_string, +) +from open_webui.utils.auth import get_admin_user, get_verified_user + -from open_webui.models.files import Files from open_webui.config import ( - BRAVE_SEARCH_API_KEY, - KAGI_SEARCH_API_KEY, - MOJEEK_SEARCH_API_KEY, - TIKTOKEN_ENCODING_NAME, - RAG_TEXT_SPLITTER, - CHUNK_OVERLAP, - CHUNK_SIZE, - CONTENT_EXTRACTION_ENGINE, - CORS_ALLOW_ORIGIN, - ENABLE_RAG_HYBRID_SEARCH, - ENABLE_RAG_LOCAL_WEB_FETCH, - ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - ENABLE_RAG_WEB_SEARCH, ENV, - GOOGLE_PSE_API_KEY, - GOOGLE_PSE_ENGINE_ID, - PDF_EXTRACT_IMAGES, - RAG_EMBEDDING_ENGINE, - RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, - RAG_EMBEDDING_BATCH_SIZE, - RAG_FILE_MAX_COUNT, - RAG_FILE_MAX_SIZE, - RAG_OPENAI_API_BASE_URL, - RAG_OPENAI_API_KEY, - RAG_OLLAMA_BASE_URL, - RAG_OLLAMA_API_KEY, - RAG_RELEVANCE_THRESHOLD, - RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, - DEFAULT_RAG_TEMPLATE, - RAG_TEMPLATE, - RAG_TOP_K, - RAG_WEB_SEARCH_CONCURRENT_REQUESTS, - RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - RAG_WEB_SEARCH_ENGINE, - RAG_WEB_SEARCH_RESULT_COUNT, - JINA_API_KEY, - SEARCHAPI_API_KEY, - SEARCHAPI_ENGINE, - SEARXNG_QUERY_URL, - SERPER_API_KEY, - SERPLY_API_KEY, - SERPSTACK_API_KEY, - SERPSTACK_HTTPS, - TAVILY_API_KEY, - BING_SEARCH_V7_ENDPOINT, - BING_SEARCH_V7_SUBSCRIPTION_KEY, - TIKA_SERVER_URL, UPLOAD_DIR, - YOUTUBE_LOADER_LANGUAGE, - YOUTUBE_LOADER_PROXY_URL, DEFAULT_LOCALE, - AppConfig, ) -from open_webui.constants import ERROR_MESSAGES from open_webui.env import ( SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER, ) -from open_webui.utils.misc import ( - calculate_sha256, - calculate_sha256_string, - extract_folders_after_data_docs, - sanitize_filename, -) -from open_webui.utils.auth import get_admin_user, get_verified_user - -from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter -from langchain_core.documents import Document - +from open_webui.constants import ERROR_MESSAGES log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) - -app.state.config = AppConfig() - -app.state.config.TOP_K = RAG_TOP_K -app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD -app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE -app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT - -app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH -app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( - ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION -) - -app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE -app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL - -app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER -app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME - -app.state.config.CHUNK_SIZE = CHUNK_SIZE -app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP - -app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE -app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL -app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE -app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL -app.state.config.RAG_TEMPLATE = RAG_TEMPLATE - -app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL -app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY - -app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL -app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY - -app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES - -app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE -app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL -app.state.YOUTUBE_LOADER_TRANSLATION = None - - -app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH -app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE -app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST - -app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL -app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY -app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID -app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY -app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY -app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY -app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY -app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS -app.state.config.SERPER_API_KEY = SERPER_API_KEY -app.state.config.SERPLY_API_KEY = SERPLY_API_KEY -app.state.config.TAVILY_API_KEY = TAVILY_API_KEY -app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY -app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE -app.state.config.JINA_API_KEY = JINA_API_KEY -app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT -app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY - -app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT -app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS +########################################## +# +# Utility functions +# +########################################## def update_embedding_model( + request: Request, embedding_model: str, auto_update: bool = False, ): - if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": + if embedding_model and request.app.state.config.RAG_EMBEDDING_ENGINE == "": from sentence_transformers import SentenceTransformer try: - app.state.sentence_transformer_ef = SentenceTransformer( + request.app.state.sentence_transformer_ef = SentenceTransformer( get_model_path(embedding_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, ) except Exception as e: log.debug(f"Error loading SentenceTransformer: {e}") - app.state.sentence_transformer_ef = None + request.app.state.sentence_transformer_ef = None else: - app.state.sentence_transformer_ef = None + request.app.state.sentence_transformer_ef = None def update_reranking_model( + request: Request, reranking_model: str, auto_update: bool = False, ): if reranking_model: if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]): try: - from open_webui.apps.retrieval.models.colbert import ColBERT + from open_webui.retrieval.models.colbert import ColBERT - app.state.sentence_transformer_rf = ColBERT( + request.app.state.sentence_transformer_rf = ColBERT( get_model_path(reranking_model, auto_update), env="docker" if DOCKER else None, ) except Exception as e: log.error(f"ColBERT: {e}") - app.state.sentence_transformer_rf = None - app.state.config.ENABLE_RAG_HYBRID_SEARCH = False + request.app.state.sentence_transformer_rf = None + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False else: import sentence_transformers try: - app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder( - get_model_path(reranking_model, auto_update), - device=DEVICE_TYPE, - trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + request.app.state.sentence_transformer_rf = ( + sentence_transformers.CrossEncoder( + get_model_path(reranking_model, auto_update), + device=DEVICE_TYPE, + trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + ) ) except: log.error("CrossEncoder error") - app.state.sentence_transformer_rf = None - app.state.config.ENABLE_RAG_HYBRID_SEARCH = False + request.app.state.sentence_transformer_rf = None + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False else: - app.state.sentence_transformer_rf = None + request.app.state.sentence_transformer_rf = None update_embedding_model( - app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.config.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, ) update_reranking_model( - app.state.config.RAG_RERANKING_MODEL, + request.app.state.config.RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, ) -app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, - ( - app.state.config.OPENAI_API_BASE_URL - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_BASE_URL - ), - ( - app.state.config.OPENAI_API_KEY - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_API_KEY - ), - app.state.config.RAG_EMBEDDING_BATCH_SIZE, -) +########################################## +# +# API routes +# +########################################## -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + +router = APIRouter() + + +request.app.state.EMBEDDING_FUNCTION = get_embedding_function( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.sentence_transformer_ef, + ( + request.app.state.config.OPENAI_API_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.OLLAMA_BASE_URL + ), + ( + request.app.state.config.OPENAI_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.OLLAMA_API_KEY + ), + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) @@ -306,43 +206,43 @@ class SearchForm(CollectionNameForm): query: str -@app.get("/") -async def get_status(): +@router.get("/") +async def get_status(request: Request): return { "status": True, - "chunk_size": app.state.config.CHUNK_SIZE, - "chunk_overlap": app.state.config.CHUNK_OVERLAP, - "template": app.state.config.RAG_TEMPLATE, - "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, - "reranking_model": app.state.config.RAG_RERANKING_MODEL, - "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "chunk_size": request.app.state.config.CHUNK_SIZE, + "chunk_overlap": request.app.state.config.CHUNK_OVERLAP, + "template": request.app.state.config.RAG_TEMPLATE, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "reranking_model": request.app.state.config.RAG_RERANKING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, } -@app.get("/embedding") -async def get_embedding_config(user=Depends(get_admin_user)): +@router.get("/embedding") +async def get_embedding_config(request: Request, user=Depends(get_admin_user)): return { "status": True, - "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, - "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { - "url": app.state.config.OPENAI_API_BASE_URL, - "key": app.state.config.OPENAI_API_KEY, + "url": request.app.state.config.OPENAI_API_BASE_URL, + "key": request.app.state.config.OPENAI_API_KEY, }, "ollama_config": { - "url": app.state.config.OLLAMA_BASE_URL, - "key": app.state.config.OLLAMA_API_KEY, + "url": request.app.state.config.OLLAMA_BASE_URL, + "key": request.app.state.config.OLLAMA_API_KEY, }, } -@app.get("/reranking") -async def get_reraanking_config(user=Depends(get_admin_user)): +@router.get("/reranking") +async def get_reraanking_config(request: Request, user=Depends(get_admin_user)): return { "status": True, - "reranking_model": app.state.config.RAG_RERANKING_MODEL, + "reranking_model": request.app.state.config.RAG_RERANKING_MODEL, } @@ -364,59 +264,63 @@ class EmbeddingModelUpdateForm(BaseModel): embedding_batch_size: Optional[int] = 1 -@app.post("/embedding/update") +@router.post("/embedding/update") async def update_embedding_config( - form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) + request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): log.info( - f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" + f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) try: - app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine - app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model + request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine + request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model - if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: + if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: if form_data.openai_config is not None: - app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url - app.state.config.OPENAI_API_KEY = form_data.openai_config.key + request.app.state.config.OPENAI_API_BASE_URL = ( + form_data.openai_config.url + ) + request.app.state.config.OPENAI_API_KEY = form_data.openai_config.key if form_data.ollama_config is not None: - app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url - app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key + request.app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url + request.app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key - app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( + form_data.embedding_batch_size + ) - update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL) + update_embedding_model(request.app.state.config.RAG_EMBEDDING_MODEL) - app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, + request.app.state.EMBEDDING_FUNCTION = get_embedding_function( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.sentence_transformer_ef, ( - app.state.config.OPENAI_API_BASE_URL - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_BASE_URL + request.app.state.config.OPENAI_API_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.OLLAMA_BASE_URL ), ( - app.state.config.OPENAI_API_KEY - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_API_KEY + request.app.state.config.OPENAI_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.OLLAMA_API_KEY ), - app.state.config.RAG_EMBEDDING_BATCH_SIZE, + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) return { "status": True, - "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, - "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { - "url": app.state.config.OPENAI_API_BASE_URL, - "key": app.state.config.OPENAI_API_KEY, + "url": request.app.state.config.OPENAI_API_BASE_URL, + "key": request.app.state.config.OPENAI_API_KEY, }, "ollama_config": { - "url": app.state.config.OLLAMA_BASE_URL, - "key": app.state.config.OLLAMA_API_KEY, + "url": request.app.state.config.OLLAMA_BASE_URL, + "key": request.app.state.config.OLLAMA_API_KEY, }, } except Exception as e: @@ -431,21 +335,21 @@ class RerankingModelUpdateForm(BaseModel): reranking_model: str -@app.post("/reranking/update") +@router.post("/reranking/update") async def update_reranking_config( - form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) + request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) ): log.info( - f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}" + f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}" ) try: - app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model + request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model - update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True) + update_reranking_model(request.app.state.config.RAG_RERANKING_MODEL, True) return { "status": True, - "reranking_model": app.state.config.RAG_RERANKING_MODEL, + "reranking_model": request.app.state.config.RAG_RERANKING_MODEL, } except Exception as e: log.exception(f"Problem updating reranking model: {e}") @@ -455,52 +359,52 @@ async def update_reranking_config( ) -@app.get("/config") -async def get_rag_config(user=Depends(get_admin_user)): +@router.get("/config") +async def get_rag_config(request: Request, user=Depends(get_admin_user)): return { "status": True, - "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, + "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, "content_extraction": { - "engine": app.state.config.CONTENT_EXTRACTION_ENGINE, - "tika_server_url": app.state.config.TIKA_SERVER_URL, + "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, + "tika_server_url": request.app.state.config.TIKA_SERVER_URL, }, "chunk": { - "text_splitter": app.state.config.TEXT_SPLITTER, - "chunk_size": app.state.config.CHUNK_SIZE, - "chunk_overlap": app.state.config.CHUNK_OVERLAP, + "text_splitter": request.app.state.config.TEXT_SPLITTER, + "chunk_size": request.app.state.config.CHUNK_SIZE, + "chunk_overlap": request.app.state.config.CHUNK_OVERLAP, }, "file": { - "max_size": app.state.config.FILE_MAX_SIZE, - "max_count": app.state.config.FILE_MAX_COUNT, + "max_size": request.app.state.config.FILE_MAX_SIZE, + "max_count": request.app.state.config.FILE_MAX_COUNT, }, "youtube": { - "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, - "translation": app.state.YOUTUBE_LOADER_TRANSLATION, - "proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL, + "language": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION, + "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, }, "web": { - "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "search": { - "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH, - "engine": app.state.config.RAG_WEB_SEARCH_ENGINE, - "searxng_query_url": app.state.config.SEARXNG_QUERY_URL, - "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY, - "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID, - "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY, - "kagi_search_api_key": app.state.config.KAGI_SEARCH_API_KEY, - "mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY, - "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, - "serpstack_https": app.state.config.SERPSTACK_HTTPS, - "serper_api_key": app.state.config.SERPER_API_KEY, - "serply_api_key": app.state.config.SERPLY_API_KEY, - "tavily_api_key": app.state.config.TAVILY_API_KEY, - "searchapi_api_key": app.state.config.SEARCHAPI_API_KEY, - "seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE, - "jina_api_key": app.state.config.JINA_API_KEY, - "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT, - "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH, + "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE, + "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL, + "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY, + "google_pse_engine_id": request.app.state.config.GOOGLE_PSE_ENGINE_ID, + "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY, + "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY, + "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY, + "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY, + "serpstack_https": request.app.state.config.SERPSTACK_HTTPS, + "serper_api_key": request.app.state.config.SERPER_API_KEY, + "serply_api_key": request.app.state.config.SERPLY_API_KEY, + "tavily_api_key": request.app.state.config.TAVILY_API_KEY, + "searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY, + "seaarchapi_engine": request.app.state.config.SEARCHAPI_ENGINE, + "jina_api_key": request.app.state.config.JINA_API_KEY, + "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT, + "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, }, } @@ -565,139 +469,159 @@ class ConfigUpdateForm(BaseModel): web: Optional[WebConfig] = None -@app.post("/config/update") -async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - app.state.config.PDF_EXTRACT_IMAGES = ( +@router.post("/config/update") +async def update_rag_config( + request: Request, form_data: ConfigUpdateForm, user=Depends(get_admin_user) +): + request.app.state.config.PDF_EXTRACT_IMAGES = ( form_data.pdf_extract_images if form_data.pdf_extract_images is not None - else app.state.config.PDF_EXTRACT_IMAGES + else request.app.state.config.PDF_EXTRACT_IMAGES ) if form_data.file is not None: - app.state.config.FILE_MAX_SIZE = form_data.file.max_size - app.state.config.FILE_MAX_COUNT = form_data.file.max_count + request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size + request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count if form_data.content_extraction is not None: log.info(f"Updating text settings: {form_data.content_extraction}") - app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine - app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url + request.app.state.config.CONTENT_EXTRACTION_ENGINE = ( + form_data.content_extraction.engine + ) + request.app.state.config.TIKA_SERVER_URL = ( + form_data.content_extraction.tika_server_url + ) if form_data.chunk is not None: - app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter - app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size - app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap + request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter + request.app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size + request.app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap if form_data.youtube is not None: - app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language - app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url - app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation + request.app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language + request.app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url + request.app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation if form_data.web is not None: - app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( + request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( # Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False form_data.web.web_loader_ssl_verification ) - app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled - app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine - app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url - app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key - app.state.config.GOOGLE_PSE_ENGINE_ID = ( + request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled + request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine + request.app.state.config.SEARXNG_QUERY_URL = ( + form_data.web.search.searxng_query_url + ) + request.app.state.config.GOOGLE_PSE_API_KEY = ( + form_data.web.search.google_pse_api_key + ) + request.app.state.config.GOOGLE_PSE_ENGINE_ID = ( form_data.web.search.google_pse_engine_id ) - app.state.config.BRAVE_SEARCH_API_KEY = ( + request.app.state.config.BRAVE_SEARCH_API_KEY = ( form_data.web.search.brave_search_api_key ) - app.state.config.KAGI_SEARCH_API_KEY = form_data.web.search.kagi_search_api_key - app.state.config.MOJEEK_SEARCH_API_KEY = ( + request.app.state.config.KAGI_SEARCH_API_KEY = ( + form_data.web.search.kagi_search_api_key + ) + request.app.state.config.MOJEEK_SEARCH_API_KEY = ( form_data.web.search.mojeek_search_api_key ) - app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key - app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https - app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key - app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key - app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key - app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key - app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine + request.app.state.config.SERPSTACK_API_KEY = ( + form_data.web.search.serpstack_api_key + ) + request.app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https + request.app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key + request.app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key + request.app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key + request.app.state.config.SEARCHAPI_API_KEY = ( + form_data.web.search.searchapi_api_key + ) + request.app.state.config.SEARCHAPI_ENGINE = ( + form_data.web.search.searchapi_engine + ) - app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key - app.state.config.BING_SEARCH_V7_ENDPOINT = ( + request.app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key + request.app.state.config.BING_SEARCH_V7_ENDPOINT = ( form_data.web.search.bing_search_v7_endpoint ) - app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = ( + request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = ( form_data.web.search.bing_search_v7_subscription_key ) - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count - app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = ( + form_data.web.search.result_count + ) + request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( form_data.web.search.concurrent_requests ) return { "status": True, - "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, + "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, "file": { - "max_size": app.state.config.FILE_MAX_SIZE, - "max_count": app.state.config.FILE_MAX_COUNT, + "max_size": request.app.state.config.FILE_MAX_SIZE, + "max_count": request.app.state.config.FILE_MAX_COUNT, }, "content_extraction": { - "engine": app.state.config.CONTENT_EXTRACTION_ENGINE, - "tika_server_url": app.state.config.TIKA_SERVER_URL, + "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, + "tika_server_url": request.app.state.config.TIKA_SERVER_URL, }, "chunk": { - "text_splitter": app.state.config.TEXT_SPLITTER, - "chunk_size": app.state.config.CHUNK_SIZE, - "chunk_overlap": app.state.config.CHUNK_OVERLAP, + "text_splitter": request.app.state.config.TEXT_SPLITTER, + "chunk_size": request.app.state.config.CHUNK_SIZE, + "chunk_overlap": request.app.state.config.CHUNK_OVERLAP, }, "youtube": { - "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, - "proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL, - "translation": app.state.YOUTUBE_LOADER_TRANSLATION, + "language": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, + "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION, }, "web": { - "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "search": { - "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH, - "engine": app.state.config.RAG_WEB_SEARCH_ENGINE, - "searxng_query_url": app.state.config.SEARXNG_QUERY_URL, - "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY, - "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID, - "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY, - "kagi_search_api_key": app.state.config.KAGI_SEARCH_API_KEY, - "mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY, - "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, - "serpstack_https": app.state.config.SERPSTACK_HTTPS, - "serper_api_key": app.state.config.SERPER_API_KEY, - "serply_api_key": app.state.config.SERPLY_API_KEY, - "serachapi_api_key": app.state.config.SEARCHAPI_API_KEY, - "searchapi_engine": app.state.config.SEARCHAPI_ENGINE, - "tavily_api_key": app.state.config.TAVILY_API_KEY, - "jina_api_key": app.state.config.JINA_API_KEY, - "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT, - "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH, + "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE, + "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL, + "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY, + "google_pse_engine_id": request.app.state.config.GOOGLE_PSE_ENGINE_ID, + "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY, + "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY, + "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY, + "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY, + "serpstack_https": request.app.state.config.SERPSTACK_HTTPS, + "serper_api_key": request.app.state.config.SERPER_API_KEY, + "serply_api_key": request.app.state.config.SERPLY_API_KEY, + "serachapi_api_key": request.app.state.config.SEARCHAPI_API_KEY, + "searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE, + "tavily_api_key": request.app.state.config.TAVILY_API_KEY, + "jina_api_key": request.app.state.config.JINA_API_KEY, + "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT, + "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, }, } -@app.get("/template") -async def get_rag_template(user=Depends(get_verified_user)): +@router.get("/template") +async def get_rag_template(request: Request, user=Depends(get_verified_user)): return { "status": True, - "template": app.state.config.RAG_TEMPLATE, + "template": request.app.state.config.RAG_TEMPLATE, } -@app.get("/query/settings") -async def get_query_settings(user=Depends(get_admin_user)): +@router.get("/query/settings") +async def get_query_settings(request: Request, user=Depends(get_admin_user)): return { "status": True, - "template": app.state.config.RAG_TEMPLATE, - "k": app.state.config.TOP_K, - "r": app.state.config.RELEVANCE_THRESHOLD, - "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, + "template": request.app.state.config.RAG_TEMPLATE, + "k": request.app.state.config.TOP_K, + "r": request.app.state.config.RELEVANCE_THRESHOLD, + "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, } @@ -708,24 +632,24 @@ class QuerySettingsForm(BaseModel): hybrid: Optional[bool] = None -@app.post("/query/settings/update") +@router.post("/query/settings/update") async def update_query_settings( - form_data: QuerySettingsForm, user=Depends(get_admin_user) + request: Request, form_data: QuerySettingsForm, user=Depends(get_admin_user) ): - app.state.config.RAG_TEMPLATE = form_data.template - app.state.config.TOP_K = form_data.k if form_data.k else 4 - app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 + request.app.state.config.RAG_TEMPLATE = form_data.template + request.app.state.config.TOP_K = form_data.k if form_data.k else 4 + request.app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 - app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( form_data.hybrid if form_data.hybrid else False ) return { "status": True, - "template": app.state.config.RAG_TEMPLATE, - "k": app.state.config.TOP_K, - "r": app.state.config.RELEVANCE_THRESHOLD, - "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, + "template": request.app.state.config.RAG_TEMPLATE, + "k": request.app.state.config.TOP_K, + "r": request.app.state.config.RELEVANCE_THRESHOLD, + "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, } @@ -736,24 +660,8 @@ async def update_query_settings( #################################### -def _get_docs_info(docs: list[Document]) -> str: - docs_info = set() - - # Trying to select relevant metadata identifying the document. - for doc in docs: - metadata = getattr(doc, "metadata", {}) - doc_name = metadata.get("name", "") - if not doc_name: - doc_name = metadata.get("title", "") - if not doc_name: - doc_name = metadata.get("source", "") - if doc_name: - docs_info.add(doc_name) - - return ", ".join(docs_info) - - def save_docs_to_vector_db( + request: Request, docs, collection_name, metadata: Optional[dict] = None, @@ -761,6 +669,22 @@ def save_docs_to_vector_db( split: bool = True, add: bool = False, ) -> bool: + def _get_docs_info(docs: list[Document]) -> str: + docs_info = set() + + # Trying to select relevant metadata identifying the document. + for doc in docs: + metadata = getattr(doc, "metadata", {}) + doc_name = metadata.get("name", "") + if not doc_name: + doc_name = metadata.get("title", "") + if not doc_name: + doc_name = metadata.get("source", "") + if doc_name: + docs_info.add(doc_name) + + return ", ".join(docs_info) + log.info( f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}" ) @@ -779,22 +703,22 @@ def save_docs_to_vector_db( raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) if split: - if app.state.config.TEXT_SPLITTER in ["", "character"]: + if request.app.state.config.TEXT_SPLITTER in ["", "character"]: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, + chunk_size=request.app.state.config.CHUNK_SIZE, + chunk_overlap=request.app.state.config.CHUNK_OVERLAP, add_start_index=True, ) - elif app.state.config.TEXT_SPLITTER == "token": + elif request.app.state.config.TEXT_SPLITTER == "token": log.info( - f"Using token text splitter: {app.state.config.TIKTOKEN_ENCODING_NAME}" + f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}" ) - tiktoken.get_encoding(str(app.state.config.TIKTOKEN_ENCODING_NAME)) + tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME)) text_splitter = TokenTextSplitter( - encoding_name=str(app.state.config.TIKTOKEN_ENCODING_NAME), - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, + encoding_name=str(request.app.state.config.TIKTOKEN_ENCODING_NAME), + chunk_size=request.app.state.config.CHUNK_SIZE, + chunk_overlap=request.app.state.config.CHUNK_OVERLAP, add_start_index=True, ) else: @@ -812,8 +736,8 @@ def save_docs_to_vector_db( **(metadata if metadata else {}), "embedding_config": json.dumps( { - "engine": app.state.config.RAG_EMBEDDING_ENGINE, - "model": app.state.config.RAG_EMBEDDING_MODEL, + "engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "model": request.app.state.config.RAG_EMBEDDING_MODEL, } ), } @@ -842,20 +766,20 @@ def save_docs_to_vector_db( log.info(f"adding to collection {collection_name}") embedding_function = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.sentence_transformer_ef, ( - app.state.config.OPENAI_API_BASE_URL - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_BASE_URL + request.app.state.config.OPENAI_API_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.OLLAMA_BASE_URL ), ( - app.state.config.OPENAI_API_KEY - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_API_KEY + request.app.state.config.OPENAI_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.OLLAMA_API_KEY ), - app.state.config.RAG_EMBEDDING_BATCH_SIZE, + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) embeddings = embedding_function( @@ -889,8 +813,9 @@ class ProcessFileForm(BaseModel): collection_name: Optional[str] = None -@app.post("/process/file") +@router.post("/process/file") def process_file( + request: Request, form_data: ProcessFileForm, user=Depends(get_verified_user), ): @@ -960,9 +885,9 @@ def process_file( if file_path: file_path = Storage.get_file(file_path) loader = Loader( - engine=app.state.config.CONTENT_EXTRACTION_ENGINE, - TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, - PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, + engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL, + PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES, ) docs = loader.load( file.filename, file.meta.get("content_type"), file_path @@ -1007,6 +932,7 @@ def process_file( try: result = save_docs_to_vector_db( + request, docs=docs, collection_name=collection_name, metadata={ @@ -1053,8 +979,9 @@ class ProcessTextForm(BaseModel): collection_name: Optional[str] = None -@app.post("/process/text") +@router.post("/process/text") def process_text( + request: Request, form_data: ProcessTextForm, user=Depends(get_verified_user), ): @@ -1071,8 +998,7 @@ def process_text( text_content = form_data.content log.debug(f"text_content: {text_content}") - result = save_docs_to_vector_db(docs, collection_name) - + result = save_docs_to_vector_db(request, docs, collection_name) if result: return { "status": True, @@ -1086,8 +1012,10 @@ def process_text( ) -@app.post("/process/youtube") -def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)): +@router.post("/process/youtube") +def process_youtube_video( + request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user) +): try: collection_name = form_data.collection_name if not collection_name: @@ -1095,14 +1023,15 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u loader = YoutubeLoader( form_data.url, - language=app.state.config.YOUTUBE_LOADER_LANGUAGE, - proxy_url=app.state.config.YOUTUBE_LOADER_PROXY_URL, + language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL, ) docs = loader.load() content = " ".join([doc.page_content for doc in docs]) log.debug(f"text_content: {content}") - save_docs_to_vector_db(docs, collection_name, overwrite=True) + + save_docs_to_vector_db(request, docs, collection_name, overwrite=True) return { "status": True, @@ -1125,8 +1054,10 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u ) -@app.post("/process/web") -def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)): +@router.post("/process/web") +def process_web( + request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user) +): try: collection_name = form_data.collection_name if not collection_name: @@ -1134,13 +1065,14 @@ def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)): loader = get_web_loader( form_data.url, - verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) docs = loader.load() content = " ".join([doc.page_content for doc in docs]) + log.debug(f"text_content: {content}") - save_docs_to_vector_db(docs, collection_name, overwrite=True) + save_docs_to_vector_db(request, docs, collection_name, overwrite=True) return { "status": True, @@ -1163,7 +1095,7 @@ def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)): ) -def search_web(engine: str, query: str) -> list[SearchResult]: +def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: """Search the web using a search engine and return the results as a list of SearchResult objects. Will look for a search engine API key in environment variables in the following order: - SEARXNG_QUERY_URL @@ -1182,150 +1114,151 @@ def search_web(engine: str, query: str) -> list[SearchResult]: # TODO: add playwright to search the web if engine == "searxng": - if app.state.config.SEARXNG_QUERY_URL: + if request.app.state.config.SEARXNG_QUERY_URL: return search_searxng( - app.state.config.SEARXNG_QUERY_URL, + request.app.state.config.SEARXNG_QUERY_URL, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SEARXNG_QUERY_URL found in environment variables") elif engine == "google_pse": if ( - app.state.config.GOOGLE_PSE_API_KEY - and app.state.config.GOOGLE_PSE_ENGINE_ID + request.app.state.config.GOOGLE_PSE_API_KEY + and request.app.state.config.GOOGLE_PSE_ENGINE_ID ): return search_google_pse( - app.state.config.GOOGLE_PSE_API_KEY, - app.state.config.GOOGLE_PSE_ENGINE_ID, + request.app.state.config.GOOGLE_PSE_API_KEY, + request.app.state.config.GOOGLE_PSE_ENGINE_ID, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception( "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables" ) elif engine == "brave": - if app.state.config.BRAVE_SEARCH_API_KEY: + if request.app.state.config.BRAVE_SEARCH_API_KEY: return search_brave( - app.state.config.BRAVE_SEARCH_API_KEY, + request.app.state.config.BRAVE_SEARCH_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables") elif engine == "kagi": - if app.state.config.KAGI_SEARCH_API_KEY: + if request.app.state.config.KAGI_SEARCH_API_KEY: return search_kagi( - app.state.config.KAGI_SEARCH_API_KEY, + request.app.state.config.KAGI_SEARCH_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No KAGI_SEARCH_API_KEY found in environment variables") elif engine == "mojeek": - if app.state.config.MOJEEK_SEARCH_API_KEY: + if request.app.state.config.MOJEEK_SEARCH_API_KEY: return search_mojeek( - app.state.config.MOJEEK_SEARCH_API_KEY, + request.app.state.config.MOJEEK_SEARCH_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables") elif engine == "serpstack": - if app.state.config.SERPSTACK_API_KEY: + if request.app.state.config.SERPSTACK_API_KEY: return search_serpstack( - app.state.config.SERPSTACK_API_KEY, + request.app.state.config.SERPSTACK_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - https_enabled=app.state.config.SERPSTACK_HTTPS, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + https_enabled=request.app.state.config.SERPSTACK_HTTPS, ) else: raise Exception("No SERPSTACK_API_KEY found in environment variables") elif engine == "serper": - if app.state.config.SERPER_API_KEY: + if request.app.state.config.SERPER_API_KEY: return search_serper( - app.state.config.SERPER_API_KEY, + request.app.state.config.SERPER_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SERPER_API_KEY found in environment variables") elif engine == "serply": - if app.state.config.SERPLY_API_KEY: + if request.app.state.config.SERPLY_API_KEY: return search_serply( - app.state.config.SERPLY_API_KEY, + request.app.state.config.SERPLY_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SERPLY_API_KEY found in environment variables") elif engine == "duckduckgo": return search_duckduckgo( query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) elif engine == "tavily": - if app.state.config.TAVILY_API_KEY: + if request.app.state.config.TAVILY_API_KEY: return search_tavily( - app.state.config.TAVILY_API_KEY, + request.app.state.config.TAVILY_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, ) else: raise Exception("No TAVILY_API_KEY found in environment variables") elif engine == "searchapi": - if app.state.config.SEARCHAPI_API_KEY: + if request.app.state.config.SEARCHAPI_API_KEY: return search_searchapi( - app.state.config.SEARCHAPI_API_KEY, - app.state.config.SEARCHAPI_ENGINE, + request.app.state.config.SEARCHAPI_API_KEY, + request.app.state.config.SEARCHAPI_ENGINE, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SEARCHAPI_API_KEY found in environment variables") elif engine == "jina": return search_jina( - app.state.config.JINA_API_KEY, + request.app.state.config.JINA_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, ) elif engine == "bing": return search_bing( - app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - app.state.config.BING_SEARCH_V7_ENDPOINT, + request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + request.app.state.config.BING_SEARCH_V7_ENDPOINT, str(DEFAULT_LOCALE), query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No search engine API key found in environment variables") -@app.post("/process/web/search") -def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)): +@router.post("/process/web/search") +def process_web_search( + request: Request, form_data: SearchForm, user=Depends(get_verified_user) +): try: logging.info( - f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}" + f"trying to web search with {request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}" ) web_results = search_web( - app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query + request, request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query ) except Exception as e: log.exception(e) - print(e) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e), @@ -1334,18 +1267,19 @@ def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)): try: collection_name = form_data.collection_name if collection_name == "": - collection_name = calculate_sha256_string(form_data.query)[:63] + collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[ + :63 + ] urls = [result.link for result in web_results] - loader = get_web_loader( - urls, - verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + urls=urls, + verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) docs = loader.aload() - save_docs_to_vector_db(docs, collection_name, overwrite=True) + save_docs_to_vector_db(request, docs, collection_name, overwrite=True) return { "status": True, @@ -1368,28 +1302,31 @@ class QueryDocForm(BaseModel): hybrid: Optional[bool] = None -@app.post("/query/doc") +@router.post("/query/doc") def query_doc_handler( + request: Request, form_data: QueryDocForm, user=Depends(get_verified_user), ): try: - if app.state.config.ENABLE_RAG_HYBRID_SEARCH: + if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_doc_with_hybrid_search( collection_name=form_data.collection_name, query=form_data.query, - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, - reranking_function=app.state.sentence_transformer_rf, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else request.app.state.config.TOP_K, + reranking_function=request.app.state.sentence_transformer_rf, r=( - form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD + form_data.r + if form_data.r + else request.app.state.config.RELEVANCE_THRESHOLD ), ) else: return query_doc( collection_name=form_data.collection_name, - query_embedding=app.state.EMBEDDING_FUNCTION(form_data.query), - k=form_data.k if form_data.k else app.state.config.TOP_K, + query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query), + k=form_data.k if form_data.k else request.app.state.config.TOP_K, ) except Exception as e: log.exception(e) @@ -1407,29 +1344,32 @@ class QueryCollectionsForm(BaseModel): hybrid: Optional[bool] = None -@app.post("/query/collection") +@router.post("/query/collection") def query_collection_handler( + request: Request, form_data: QueryCollectionsForm, user=Depends(get_verified_user), ): try: - if app.state.config.ENABLE_RAG_HYBRID_SEARCH: + if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_collection_with_hybrid_search( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, - reranking_function=app.state.sentence_transformer_rf, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else request.app.state.config.TOP_K, + reranking_function=request.app.state.sentence_transformer_rf, r=( - form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD + form_data.r + if form_data.r + else request.app.state.config.RELEVANCE_THRESHOLD ), ) else: return query_collection( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else request.app.state.config.TOP_K, ) except Exception as e: @@ -1452,7 +1392,7 @@ class DeleteForm(BaseModel): file_id: str -@app.post("/delete") +@router.post("/delete") def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)): try: if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name): @@ -1471,13 +1411,13 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin return {"status": False} -@app.post("/reset/db") +@router.post("/reset/db") def reset_vector_db(user=Depends(get_admin_user)): VECTOR_DB_CLIENT.reset() Knowledges.delete_all_knowledge() -@app.post("/reset/uploads") +@router.post("/reset/uploads") def reset_upload_dir(user=Depends(get_admin_user)) -> bool: folder = f"{UPLOAD_DIR}" try: @@ -1502,10 +1442,6 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool: if ENV == "dev": - @app.get("/ef") - async def get_embeddings(): - return {"result": app.state.EMBEDDING_FUNCTION("hello world")} - - @app.get("/ef/{text}") - async def get_embeddings_text(text: str): - return {"result": app.state.EMBEDDING_FUNCTION(text)} + @router.get("/ef/{text}") + async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"): + return {"result": request.app.state.EMBEDDING_FUNCTION(text)} From b3987ad41e3d7c0f73dc41694d3bac2a37966f5a Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 18:08:55 -0800 Subject: [PATCH 09/26] wip --- backend/open_webui/main.py | 39 +++++++++++++++++++++++-- backend/open_webui/routers/retrieval.py | 29 ------------------ 2 files changed, 36 insertions(+), 32 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 8184b467b..5261d440f 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -62,8 +62,12 @@ from open_webui.routers import ( users, utils, ) - from open_webui.retrieval.utils import get_sources_from_files +from open_webui.routers.retrieval import ( + get_embedding_function, + update_embedding_model, + update_reranking_model, +) from open_webui.socket.main import ( @@ -73,15 +77,16 @@ from open_webui.socket.main import ( get_event_emitter, ) - from open_webui.internal.db import Session -from backend.open_webui.routers.webui import ( +from open_webui.routers.webui import ( app as webui_app, generate_function_chat_completion, get_all_models as get_open_webui_models, ) + + from open_webui.models.functions import Functions from open_webui.models.models import Models from open_webui.models.users import UserModel, Users @@ -523,6 +528,34 @@ app.state.sentence_transformer_rf = None app.state.YOUTUBE_LOADER_TRANSLATION = None +app.state.EMBEDDING_FUNCTION = get_embedding_function( + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, + app.state.sentence_transformer_ef, + ( + app.state.config.OPENAI_API_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_BASE_URL + ), + ( + app.state.config.OPENAI_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_API_KEY + ), + app.state.config.RAG_EMBEDDING_BATCH_SIZE, +) + +update_embedding_model( + app.state.config.RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, +) + +update_reranking_model( + app.state.config.RAG_RERANKING_MODEL, + RAG_RERANKING_MODEL_AUTO_UPDATE, +) + + ######################################## # # IMAGES diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 7e0dc6018..5cd7209a8 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -155,17 +155,6 @@ def update_reranking_model( request.app.state.sentence_transformer_rf = None -update_embedding_model( - request.app.state.config.RAG_EMBEDDING_MODEL, - RAG_EMBEDDING_MODEL_AUTO_UPDATE, -) - -update_reranking_model( - request.app.state.config.RAG_RERANKING_MODEL, - RAG_RERANKING_MODEL_AUTO_UPDATE, -) - - ########################################## # # API routes @@ -176,24 +165,6 @@ update_reranking_model( router = APIRouter() -request.app.state.EMBEDDING_FUNCTION = get_embedding_function( - request.app.state.config.RAG_EMBEDDING_ENGINE, - request.app.state.config.RAG_EMBEDDING_MODEL, - request.app.state.sentence_transformer_ef, - ( - request.app.state.config.OPENAI_API_BASE_URL - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.OLLAMA_BASE_URL - ), - ( - request.app.state.config.OPENAI_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.OLLAMA_API_KEY - ), - request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, -) - - class CollectionNameForm(BaseModel): collection_name: Optional[str] = None From 9e85ed861dd406d4d81e810cbba1e1eb7b7bd357 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 18:16:07 -0800 Subject: [PATCH 10/26] wip: pipelines --- backend/open_webui/routers/pipelines.py | 169 ++++++++++++++---------- 1 file changed, 100 insertions(+), 69 deletions(-) diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py index 9450d520b..f1cdae140 100644 --- a/backend/open_webui/routers/pipelines.py +++ b/backend/open_webui/routers/pipelines.py @@ -1,17 +1,33 @@ -from fastapi import APIRouter, Depends, HTTPException, Response, status +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + Request, + UploadFile, + status, + APIRouter, +) +import os +import logging +import shutil +import requests from pydantic import BaseModel from starlette.responses import FileResponse +from typing import Optional - -from open_webui.models.chats import ChatTitleMessagesForm -from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT +from open_webui.env import SRC_LOG_LEVELS +from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES -from open_webui.utils.misc import get_gravatar_url -from open_webui.utils.pdf_generator import PDFGenerator + +from open_webui.routers.openai import get_all_models_responses + from open_webui.utils.auth import get_admin_user -router = APIRouter() +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) ################################## @@ -20,15 +36,14 @@ router = APIRouter() # ################################## - -# TODO: Refactor pipelines API endpoints below into a separate file +router = APIRouter() -@app.get("/api/pipelines/list") -async def get_pipelines_list(user=Depends(get_admin_user)): - responses = await get_openai_models_responses() - +@router.get("/api/pipelines/list") +async def get_pipelines_list(request: Request, user=Depends(get_admin_user)): + responses = await get_all_models_responses(request) log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}") + urlIdxs = [ idx for idx, response in enumerate(responses) @@ -38,7 +53,7 @@ async def get_pipelines_list(user=Depends(get_admin_user)): return { "data": [ { - "url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx], + "url": request.app.state.config.OPENAI_API_BASE_URLS[urlIdx], "idx": urlIdx, } for urlIdx in urlIdxs @@ -46,9 +61,12 @@ async def get_pipelines_list(user=Depends(get_admin_user)): } -@app.post("/api/pipelines/upload") +@router.post("/api/pipelines/upload") async def upload_pipeline( - urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) + request: Request, + urlIdx: int = Form(...), + file: UploadFile = File(...), + user=Depends(get_admin_user), ): print("upload_pipeline", urlIdx, file.filename) # Check if the uploaded file is a python file @@ -68,14 +86,16 @@ async def upload_pipeline( with open(file_path, "wb") as buffer: shutil.copyfileobj(file.file, buffer) - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] with open(file_path, "rb") as f: files = {"file": f} - r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files) + r = requests.post( + f"{url}/pipelines/upload", + headers={"Authorization": f"Bearer {key}"}, + files=files, + ) r.raise_for_status() data = r.json() @@ -85,7 +105,7 @@ async def upload_pipeline( # Handle connection error here print(f"Connection error: {e}") - detail = "Pipeline not found" + detail = None status_code = status.HTTP_404_NOT_FOUND if r is not None: status_code = r.status_code @@ -98,7 +118,7 @@ async def upload_pipeline( raise HTTPException( status_code=status_code, - detail=detail, + detail=detail if detail else "Pipeline not found", ) finally: # Ensure the file is deleted after the upload is completed or on failure @@ -111,18 +131,21 @@ class AddPipelineForm(BaseModel): urlIdx: int -@app.post("/api/pipelines/add") -async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)): +@router.post("/api/pipelines/add") +async def add_pipeline( + request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user) +): r = None try: urlIdx = form_data.urlIdx - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - headers = {"Authorization": f"Bearer {key}"} r = requests.post( - f"{url}/pipelines/add", headers=headers, json={"url": form_data.url} + f"{url}/pipelines/add", + headers={"Authorization": f"Bearer {key}"}, + json={"url": form_data.url}, ) r.raise_for_status() @@ -133,7 +156,7 @@ async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)) # Handle connection error here print(f"Connection error: {e}") - detail = "Pipeline not found" + detail = None if r is not None: try: res = r.json() @@ -144,7 +167,7 @@ async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)) raise HTTPException( status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, + detail=detail if detail else "Pipeline not found", ) @@ -153,18 +176,21 @@ class DeletePipelineForm(BaseModel): urlIdx: int -@app.delete("/api/pipelines/delete") -async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)): +@router.delete("/api/pipelines/delete") +async def delete_pipeline( + request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user) +): r = None try: urlIdx = form_data.urlIdx - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - headers = {"Authorization": f"Bearer {key}"} r = requests.delete( - f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id} + f"{url}/pipelines/delete", + headers={"Authorization": f"Bearer {key}"}, + json={"id": form_data.id}, ) r.raise_for_status() @@ -175,7 +201,7 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_ # Handle connection error here print(f"Connection error: {e}") - detail = "Pipeline not found" + detail = None if r is not None: try: res = r.json() @@ -186,19 +212,20 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_ raise HTTPException( status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, + detail=detail if detail else "Pipeline not found", ) -@app.get("/api/pipelines") -async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)): +@router.get("/api/pipelines") +async def get_pipelines( + request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user) +): r = None try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/pipelines", headers=headers) + r = requests.get(f"{url}/pipelines", headers={"Authorization": f"Bearer {key}"}) r.raise_for_status() data = r.json() @@ -208,7 +235,7 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use # Handle connection error here print(f"Connection error: {e}") - detail = "Pipeline not found" + detail = None if r is not None: try: res = r.json() @@ -219,23 +246,25 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use raise HTTPException( status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, + detail=detail if detail else "Pipeline not found", ) -@app.get("/api/pipelines/{pipeline_id}/valves") +@router.get("/api/pipelines/{pipeline_id}/valves") async def get_pipeline_valves( + request: Request, urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user), ): r = None try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers) + r = requests.get( + f"{url}/{pipeline_id}/valves", headers={"Authorization": f"Bearer {key}"} + ) r.raise_for_status() data = r.json() @@ -245,8 +274,7 @@ async def get_pipeline_valves( # Handle connection error here print(f"Connection error: {e}") - detail = "Pipeline not found" - + detail = None if r is not None: try: res = r.json() @@ -257,23 +285,26 @@ async def get_pipeline_valves( raise HTTPException( status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, + detail=detail if detail else "Pipeline not found", ) -@app.get("/api/pipelines/{pipeline_id}/valves/spec") +@router.get("/api/pipelines/{pipeline_id}/valves/spec") async def get_pipeline_valves_spec( + request: Request, urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user), ): r = None try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers) + r = requests.get( + f"{url}/{pipeline_id}/valves/spec", + headers={"Authorization": f"Bearer {key}"}, + ) r.raise_for_status() data = r.json() @@ -283,7 +314,7 @@ async def get_pipeline_valves_spec( # Handle connection error here print(f"Connection error: {e}") - detail = "Pipeline not found" + detail = None if r is not None: try: res = r.json() @@ -294,12 +325,13 @@ async def get_pipeline_valves_spec( raise HTTPException( status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, + detail=detail if detail else "Pipeline not found", ) -@app.post("/api/pipelines/{pipeline_id}/valves/update") +@router.post("/api/pipelines/{pipeline_id}/valves/update") async def update_pipeline_valves( + request: Request, urlIdx: Optional[int], pipeline_id: str, form_data: dict, @@ -307,13 +339,12 @@ async def update_pipeline_valves( ): r = None try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - headers = {"Authorization": f"Bearer {key}"} r = requests.post( f"{url}/{pipeline_id}/valves/update", - headers=headers, + headers={"Authorization": f"Bearer {key}"}, json={**form_data}, ) @@ -325,7 +356,7 @@ async def update_pipeline_valves( # Handle connection error here print(f"Connection error: {e}") - detail = "Pipeline not found" + detail = None if r is not None: try: @@ -337,5 +368,5 @@ async def update_pipeline_valves( raise HTTPException( status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, + detail=detail if detail else "Pipeline not found", ) From 3bda1a8b887261d71004eea7ee180c62718c822e Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 18:36:59 -0800 Subject: [PATCH 11/26] wip --- backend/open_webui/main.py | 345 ++++++++++++++++++++-- backend/open_webui/routers/files.py | 2 +- backend/open_webui/routers/functions.py | 2 +- backend/open_webui/routers/knowledge.py | 2 +- backend/open_webui/routers/tools.py | 2 +- backend/open_webui/routers/webui.py | 323 +------------------- backend/open_webui/socket/main.py | 2 +- backend/open_webui/test/util/mock_user.py | 2 +- backend/open_webui/utils/tools.py | 2 +- 9 files changed, 335 insertions(+), 347 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 5261d440f..486311902 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -8,6 +8,8 @@ import shutil import sys import time import random +from typing import AsyncGenerator, Generator, Iterator + from contextlib import asynccontextmanager from urllib.parse import urlencode, parse_qs, urlparse from pydantic import BaseModel @@ -39,7 +41,6 @@ from starlette.responses import Response, StreamingResponse from open_webui.routers import ( audio, - chat, images, ollama, openai, @@ -90,7 +91,7 @@ from open_webui.routers.webui import ( from open_webui.models.functions import Functions from open_webui.models.models import Models from open_webui.models.users import UserModel, Users -from backend.open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.plugin import load_function_module_by_id from open_webui.constants import TASKS @@ -283,8 +284,13 @@ from open_webui.utils.misc import ( add_or_update_system_message, get_last_user_message, prepend_to_first_user_message_content, + openai_chat_chunk_message_template, + openai_chat_completion_message_template, +) +from open_webui.utils.payload import ( + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, ) - from open_webui.utils.payload import convert_payload_openai_to_ollama from open_webui.utils.response import ( @@ -1441,8 +1447,12 @@ app.add_middleware( app.mount("/ws", socket_app) -app.include_router(ollama.router, prefix="/ollama") -app.include_router(openai.router, prefix="/openai") +app.include_router(ollama.router, prefix="/ollama", tags=["ollama"]) +app.include_router(openai.router, prefix="/openai", tags=["openai"]) + + +app.include_router(pipelines.router, prefix="/pipelines", tags=["pipelines"]) +app.include_router(tasks.router, prefix="/tasks", tags=["tasks"]) app.include_router(images.router, prefix="/api/v1/images", tags=["images"]) @@ -1473,8 +1483,277 @@ app.include_router( app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"]) +################################## +# +# Chat Endpoints +# +################################## + + +def get_function_module(pipe_id: str): + # Check if function is already loaded + if pipe_id not in app.state.FUNCTIONS: + function_module, _, _ = load_function_module_by_id(pipe_id) + app.state.FUNCTIONS[pipe_id] = function_module + else: + function_module = app.state.FUNCTIONS[pipe_id] + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(pipe_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + return function_module + + +async def get_function_models(): + pipes = Functions.get_functions_by_type("pipe", active_only=True) + pipe_models = [] + + for pipe in pipes: + function_module = get_function_module(pipe.id) + + # Check if function is a manifold + if hasattr(function_module, "pipes"): + sub_pipes = [] + + # Check if pipes is a function or a list + + try: + if callable(function_module.pipes): + sub_pipes = function_module.pipes() + else: + sub_pipes = function_module.pipes + except Exception as e: + log.exception(e) + sub_pipes = [] + + log.debug( + f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}" + ) + + for p in sub_pipes: + sub_pipe_id = f'{pipe.id}.{p["id"]}' + sub_pipe_name = p["name"] + + if hasattr(function_module, "name"): + sub_pipe_name = f"{function_module.name}{sub_pipe_name}" + + pipe_flag = {"type": pipe.type} + + pipe_models.append( + { + "id": sub_pipe_id, + "name": sub_pipe_name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": pipe_flag, + } + ) + else: + pipe_flag = {"type": "pipe"} + + log.debug( + f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}" + ) + + pipe_models.append( + { + "id": pipe.id, + "name": pipe.name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": pipe_flag, + } + ) + + return pipe_models + + +async def generate_function_chat_completion(form_data, user, models: dict = {}): + async def execute_pipe(pipe, params): + if inspect.iscoroutinefunction(pipe): + return await pipe(**params) + else: + return pipe(**params) + + async def get_message_content(res: str | Generator | AsyncGenerator) -> str: + if isinstance(res, str): + return res + if isinstance(res, Generator): + return "".join(map(str, res)) + if isinstance(res, AsyncGenerator): + return "".join([str(stream) async for stream in res]) + + def process_line(form_data: dict, line): + if isinstance(line, BaseModel): + line = line.model_dump_json() + line = f"data: {line}" + if isinstance(line, dict): + line = f"data: {json.dumps(line)}" + + try: + line = line.decode("utf-8") + except Exception: + pass + + if line.startswith("data:"): + return f"{line}\n\n" + else: + line = openai_chat_chunk_message_template(form_data["model"], line) + return f"data: {json.dumps(line)}\n\n" + + def get_pipe_id(form_data: dict) -> str: + pipe_id = form_data["model"] + if "." in pipe_id: + pipe_id, _ = pipe_id.split(".", 1) + return pipe_id + + def get_function_params(function_module, form_data, user, extra_params=None): + if extra_params is None: + extra_params = {} + + pipe_id = get_pipe_id(form_data) + + # Get the signature of the function + sig = inspect.signature(function_module.pipe) + params = {"body": form_data} | { + k: v for k, v in extra_params.items() if k in sig.parameters + } + + if "__user__" in params and hasattr(function_module, "UserValves"): + user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) + try: + params["__user__"]["valves"] = function_module.UserValves(**user_valves) + except Exception as e: + log.exception(e) + params["__user__"]["valves"] = function_module.UserValves() + + return params + + model_id = form_data.get("model") + model_info = Models.get_model_by_id(model_id) + + metadata = form_data.pop("metadata", {}) + + files = metadata.get("files", []) + tool_ids = metadata.get("tool_ids", []) + # Check if tool_ids is None + if tool_ids is None: + tool_ids = [] + + __event_emitter__ = None + __event_call__ = None + __task__ = None + __task_body__ = None + + if metadata: + if all(k in metadata for k in ("session_id", "chat_id", "message_id")): + __event_emitter__ = get_event_emitter(metadata) + __event_call__ = get_event_call(metadata) + __task__ = metadata.get("task", None) + __task_body__ = metadata.get("task_body", None) + + extra_params = { + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__task__": __task__, + "__task_body__": __task_body__, + "__files__": files, + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + "__metadata__": metadata, + } + extra_params["__tools__"] = get_tools( + app, + tool_ids, + user, + { + **extra_params, + "__model__": models.get(form_data["model"], None), + "__messages__": form_data["messages"], + "__files__": files, + }, + ) + + if model_info: + if model_info.base_model_id: + form_data["model"] = model_info.base_model_id + + params = model_info.params.model_dump() + form_data = apply_model_params_to_body_openai(params, form_data) + form_data = apply_model_system_prompt_to_body(params, form_data, user) + + pipe_id = get_pipe_id(form_data) + function_module = get_function_module(pipe_id) + + pipe = function_module.pipe + params = get_function_params(function_module, form_data, user, extra_params) + + if form_data.get("stream", False): + + async def stream_content(): + try: + res = await execute_pipe(pipe, params) + + # Directly return if the response is a StreamingResponse + if isinstance(res, StreamingResponse): + async for data in res.body_iterator: + yield data + return + if isinstance(res, dict): + yield f"data: {json.dumps(res)}\n\n" + return + + except Exception as e: + log.error(f"Error: {e}") + yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" + return + + if isinstance(res, str): + message = openai_chat_chunk_message_template(form_data["model"], res) + yield f"data: {json.dumps(message)}\n\n" + + if isinstance(res, Iterator): + for line in res: + yield process_line(form_data, line) + + if isinstance(res, AsyncGenerator): + async for line in res: + yield process_line(form_data, line) + + if isinstance(res, str) or isinstance(res, Generator): + finish_message = openai_chat_chunk_message_template( + form_data["model"], "" + ) + finish_message["choices"][0]["finish_reason"] = "stop" + yield f"data: {json.dumps(finish_message)}\n\n" + yield "data: [DONE]" + + return StreamingResponse(stream_content(), media_type="text/event-stream") + else: + try: + res = await execute_pipe(pipe, params) + + except Exception as e: + log.error(f"Error: {e}") + return {"error": {"detail": str(e)}} + + if isinstance(res, StreamingResponse) or isinstance(res, dict): + return res + if isinstance(res, BaseModel): + return res.model_dump() + + message = await get_message_content(res) + return openai_chat_completion_message_template(form_data["model"], message) + + async def get_all_base_models(): - open_webui_models = [] + function_models = [] openai_models = [] ollama_models = [] @@ -1496,9 +1775,44 @@ async def get_all_base_models(): for model in ollama_models["models"] ] - open_webui_models = await get_open_webui_models() + function_models = await get_function_models() + models = function_models + openai_models + ollama_models + + # Add arena models + if app.state.config.ENABLE_EVALUATION_ARENA_MODELS: + arena_models = [] + if len(app.state.config.EVALUATION_ARENA_MODELS) > 0: + arena_models = [ + { + "id": model["id"], + "name": model["name"], + "info": { + "meta": model["meta"], + }, + "object": "model", + "created": int(time.time()), + "owned_by": "arena", + "arena": True, + } + for model in app.state.config.EVALUATION_ARENA_MODELS + ] + else: + # Add default arena model + arena_models = [ + { + "id": DEFAULT_ARENA_MODEL["id"], + "name": DEFAULT_ARENA_MODEL["name"], + "info": { + "meta": DEFAULT_ARENA_MODEL["meta"], + }, + "object": "model", + "created": int(time.time()), + "owned_by": "arena", + "arena": True, + } + ] + models = models + arena_models - models = open_webui_models + openai_models + ollama_models return models @@ -1628,6 +1942,7 @@ async def get_all_models(): ) log.debug(f"get_all_models() returned {len(models)} models") + app.state.MODELS = {model["id"]: model for model in models} return models @@ -1689,16 +2004,8 @@ async def get_base_models(user=Depends(get_admin_user)): return {"data": models} -################################## -# -# Chat Endpoints -# -################################## - - @app.post("/api/chat/completions") async def generate_chat_completions( - request: Request, form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False, @@ -1706,7 +2013,7 @@ async def generate_chat_completions( if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True - model_list = request.state.models + model_list = app.state.MODELS models = {model["id"]: model for model in model_list} model_id = form_data["model"] @@ -1843,8 +2150,8 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): try: urlIdx = filter["urlIdx"] - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + url = app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = app.state.config.OPENAI_API_KEYS[urlIdx] if key != "": headers = {"Authorization": f"Bearer {key}"} diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 49deb998f..bc553ab26 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -14,7 +14,7 @@ from open_webui.models.files import ( FileModelResponse, Files, ) -from backend.open_webui.routers.retrieval import process_file, ProcessFileForm +from open_webui.routers.retrieval import process_file, ProcessFileForm from open_webui.config import UPLOAD_DIR from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index bb780f112..7f3305f25 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -8,7 +8,7 @@ from open_webui.models.functions import ( FunctionResponse, Functions, ) -from backend.open_webui.utils.plugin import load_function_module_by_id, replace_imports +from open_webui.utils.plugin import load_function_module_by_id, replace_imports from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 85a4d30fd..0f4dd9283 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -12,7 +12,7 @@ from open_webui.models.knowledge import ( ) from open_webui.models.files import Files, FileModel from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT -from backend.open_webui.routers.retrieval import process_file, ProcessFileForm +from open_webui.routers.retrieval import process_file, ProcessFileForm from open_webui.constants import ERROR_MESSAGES diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index bf4a93f6c..9e95ebe5a 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -8,7 +8,7 @@ from open_webui.models.tools import ( ToolUserResponse, Tools, ) -from backend.open_webui.utils.plugin import load_tools_module_by_id, replace_imports +from open_webui.utils.plugin import load_tools_module_by_id, replace_imports from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status diff --git a/backend/open_webui/routers/webui.py b/backend/open_webui/routers/webui.py index d3942db97..058d4c365 100644 --- a/backend/open_webui/routers/webui.py +++ b/backend/open_webui/routers/webui.py @@ -4,7 +4,7 @@ import logging import time from typing import AsyncGenerator, Generator, Iterator -from open_webui.apps.socket.main import get_event_call, get_event_emitter +from open_webui.socket.main import get_event_call, get_event_emitter from open_webui.models.functions import Functions from open_webui.models.models import Models from open_webui.routers import ( @@ -24,7 +24,7 @@ from open_webui.routers import ( users, utils, ) -from backend.open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.plugin import load_function_module_by_id from open_webui.config import ( ADMIN_EMAIL, CORS_ALLOW_ORIGIN, @@ -92,322 +92,3 @@ from open_webui.utils.tools import get_tools log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) - - -@app.get("/") -async def get_status(): - return { - "status": True, - "auth": WEBUI_AUTH, - "default_models": app.state.config.DEFAULT_MODELS, - "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, - } - - -async def get_all_models(): - models = [] - pipe_models = await get_pipe_models() - models = models + pipe_models - - if app.state.config.ENABLE_EVALUATION_ARENA_MODELS: - arena_models = [] - if len(app.state.config.EVALUATION_ARENA_MODELS) > 0: - arena_models = [ - { - "id": model["id"], - "name": model["name"], - "info": { - "meta": model["meta"], - }, - "object": "model", - "created": int(time.time()), - "owned_by": "arena", - "arena": True, - } - for model in app.state.config.EVALUATION_ARENA_MODELS - ] - else: - # Add default arena model - arena_models = [ - { - "id": DEFAULT_ARENA_MODEL["id"], - "name": DEFAULT_ARENA_MODEL["name"], - "info": { - "meta": DEFAULT_ARENA_MODEL["meta"], - }, - "object": "model", - "created": int(time.time()), - "owned_by": "arena", - "arena": True, - } - ] - models = models + arena_models - return models - - -def get_function_module(pipe_id: str): - # Check if function is already loaded - if pipe_id not in app.state.FUNCTIONS: - function_module, _, _ = load_function_module_by_id(pipe_id) - app.state.FUNCTIONS[pipe_id] = function_module - else: - function_module = app.state.FUNCTIONS[pipe_id] - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(pipe_id) - function_module.valves = function_module.Valves(**(valves if valves else {})) - return function_module - - -async def get_pipe_models(): - pipes = Functions.get_functions_by_type("pipe", active_only=True) - pipe_models = [] - - for pipe in pipes: - function_module = get_function_module(pipe.id) - - # Check if function is a manifold - if hasattr(function_module, "pipes"): - sub_pipes = [] - - # Check if pipes is a function or a list - - try: - if callable(function_module.pipes): - sub_pipes = function_module.pipes() - else: - sub_pipes = function_module.pipes - except Exception as e: - log.exception(e) - sub_pipes = [] - - log.debug( - f"get_pipe_models: function '{pipe.id}' is a manifold of {sub_pipes}" - ) - - for p in sub_pipes: - sub_pipe_id = f'{pipe.id}.{p["id"]}' - sub_pipe_name = p["name"] - - if hasattr(function_module, "name"): - sub_pipe_name = f"{function_module.name}{sub_pipe_name}" - - pipe_flag = {"type": pipe.type} - - pipe_models.append( - { - "id": sub_pipe_id, - "name": sub_pipe_name, - "object": "model", - "created": pipe.created_at, - "owned_by": "openai", - "pipe": pipe_flag, - } - ) - else: - pipe_flag = {"type": "pipe"} - - log.debug( - f"get_pipe_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}" - ) - - pipe_models.append( - { - "id": pipe.id, - "name": pipe.name, - "object": "model", - "created": pipe.created_at, - "owned_by": "openai", - "pipe": pipe_flag, - } - ) - - return pipe_models - - -async def execute_pipe(pipe, params): - if inspect.iscoroutinefunction(pipe): - return await pipe(**params) - else: - return pipe(**params) - - -async def get_message_content(res: str | Generator | AsyncGenerator) -> str: - if isinstance(res, str): - return res - if isinstance(res, Generator): - return "".join(map(str, res)) - if isinstance(res, AsyncGenerator): - return "".join([str(stream) async for stream in res]) - - -def process_line(form_data: dict, line): - if isinstance(line, BaseModel): - line = line.model_dump_json() - line = f"data: {line}" - if isinstance(line, dict): - line = f"data: {json.dumps(line)}" - - try: - line = line.decode("utf-8") - except Exception: - pass - - if line.startswith("data:"): - return f"{line}\n\n" - else: - line = openai_chat_chunk_message_template(form_data["model"], line) - return f"data: {json.dumps(line)}\n\n" - - -def get_pipe_id(form_data: dict) -> str: - pipe_id = form_data["model"] - if "." in pipe_id: - pipe_id, _ = pipe_id.split(".", 1) - - return pipe_id - - -def get_function_params(function_module, form_data, user, extra_params=None): - if extra_params is None: - extra_params = {} - - pipe_id = get_pipe_id(form_data) - - # Get the signature of the function - sig = inspect.signature(function_module.pipe) - params = {"body": form_data} | { - k: v for k, v in extra_params.items() if k in sig.parameters - } - - if "__user__" in params and hasattr(function_module, "UserValves"): - user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) - try: - params["__user__"]["valves"] = function_module.UserValves(**user_valves) - except Exception as e: - log.exception(e) - params["__user__"]["valves"] = function_module.UserValves() - - return params - - -async def generate_function_chat_completion(form_data, user, models: dict = {}): - model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id) - - metadata = form_data.pop("metadata", {}) - - files = metadata.get("files", []) - tool_ids = metadata.get("tool_ids", []) - # Check if tool_ids is None - if tool_ids is None: - tool_ids = [] - - __event_emitter__ = None - __event_call__ = None - __task__ = None - __task_body__ = None - - if metadata: - if all(k in metadata for k in ("session_id", "chat_id", "message_id")): - __event_emitter__ = get_event_emitter(metadata) - __event_call__ = get_event_call(metadata) - __task__ = metadata.get("task", None) - __task_body__ = metadata.get("task_body", None) - - extra_params = { - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__task__": __task__, - "__task_body__": __task_body__, - "__files__": files, - "__user__": { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - "__metadata__": metadata, - } - extra_params["__tools__"] = get_tools( - app, - tool_ids, - user, - { - **extra_params, - "__model__": models.get(form_data["model"], None), - "__messages__": form_data["messages"], - "__files__": files, - }, - ) - - if model_info: - if model_info.base_model_id: - form_data["model"] = model_info.base_model_id - - params = model_info.params.model_dump() - form_data = apply_model_params_to_body_openai(params, form_data) - form_data = apply_model_system_prompt_to_body(params, form_data, user) - - pipe_id = get_pipe_id(form_data) - function_module = get_function_module(pipe_id) - - pipe = function_module.pipe - params = get_function_params(function_module, form_data, user, extra_params) - - if form_data.get("stream", False): - - async def stream_content(): - try: - res = await execute_pipe(pipe, params) - - # Directly return if the response is a StreamingResponse - if isinstance(res, StreamingResponse): - async for data in res.body_iterator: - yield data - return - if isinstance(res, dict): - yield f"data: {json.dumps(res)}\n\n" - return - - except Exception as e: - log.error(f"Error: {e}") - yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" - return - - if isinstance(res, str): - message = openai_chat_chunk_message_template(form_data["model"], res) - yield f"data: {json.dumps(message)}\n\n" - - if isinstance(res, Iterator): - for line in res: - yield process_line(form_data, line) - - if isinstance(res, AsyncGenerator): - async for line in res: - yield process_line(form_data, line) - - if isinstance(res, str) or isinstance(res, Generator): - finish_message = openai_chat_chunk_message_template( - form_data["model"], "" - ) - finish_message["choices"][0]["finish_reason"] = "stop" - yield f"data: {json.dumps(finish_message)}\n\n" - yield "data: [DONE]" - - return StreamingResponse(stream_content(), media_type="text/event-stream") - else: - try: - res = await execute_pipe(pipe, params) - - except Exception as e: - log.error(f"Error: {e}") - return {"error": {"detail": str(e)}} - - if isinstance(res, StreamingResponse) or isinstance(res, dict): - return res - if isinstance(res, BaseModel): - return res.model_dump() - - message = await get_message_content(res) - return openai_chat_completion_message_template(form_data["model"], message) diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 7a673f098..ba5eeb6ae 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -13,7 +13,7 @@ from open_webui.env import ( WEBSOCKET_REDIS_URL, ) from open_webui.utils.auth import decode_token -from open_webui.apps.socket.utils import RedisDict +from open_webui.socket.utils import RedisDict from open_webui.env import ( GLOBAL_LOG_LEVEL, diff --git a/backend/open_webui/test/util/mock_user.py b/backend/open_webui/test/util/mock_user.py index e25256460..7ce64dffa 100644 --- a/backend/open_webui/test/util/mock_user.py +++ b/backend/open_webui/test/util/mock_user.py @@ -5,7 +5,7 @@ from fastapi import FastAPI @contextmanager def mock_webui_user(**kwargs): - from backend.open_webui.routers.webui import app + from open_webui.routers.webui import app with mock_user(app, **kwargs): yield diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index a88e71f20..0b9161f86 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -7,7 +7,7 @@ from functools import update_wrapper, partial from langchain_core.utils.function_calling import convert_to_openai_function from open_webui.models.tools import Tools from open_webui.models.users import UserModel -from backend.open_webui.utils.plugin import load_tools_module_by_id +from open_webui.utils.plugin import load_tools_module_by_id from pydantic import BaseModel, Field, create_model log = logging.getLogger(__name__) From ccdf51588eeecc0f8ccf36c87a360b0c36f83856 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 18:46:29 -0800 Subject: [PATCH 12/26] wip --- backend/open_webui/main.py | 76 ++++++++++++------------- backend/open_webui/routers/retrieval.py | 62 +++++++++++--------- 2 files changed, 70 insertions(+), 68 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 486311902..09913ae01 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -39,6 +39,13 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.responses import Response, StreamingResponse +from open_webui.socket.main import ( + app as socket_app, + periodic_usage_pool_cleanup, + get_event_call, + get_event_emitter, +) + from open_webui.routers import ( audio, images, @@ -63,35 +70,19 @@ from open_webui.routers import ( users, utils, ) -from open_webui.retrieval.utils import get_sources_from_files from open_webui.routers.retrieval import ( get_embedding_function, - update_embedding_model, - update_reranking_model, + get_ef, + get_rf, ) +from open_webui.retrieval.utils import get_sources_from_files -from open_webui.socket.main import ( - app as socket_app, - periodic_usage_pool_cleanup, - get_event_call, - get_event_emitter, -) - from open_webui.internal.db import Session - -from open_webui.routers.webui import ( - app as webui_app, - generate_function_chat_completion, - get_all_models as get_open_webui_models, -) - - from open_webui.models.functions import Functions from open_webui.models.models import Models from open_webui.models.users import UserModel, Users -from open_webui.utils.plugin import load_function_module_by_id from open_webui.constants import TASKS @@ -279,7 +270,7 @@ from open_webui.env import ( OFFLINE_MODE, ) - +from open_webui.utils.plugin import load_function_module_by_id from open_webui.utils.misc import ( add_or_update_system_message, get_last_user_message, @@ -528,8 +519,8 @@ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS app.state.EMBEDDING_FUNCTION = None -app.state.sentence_transformer_ef = None -app.state.sentence_transformer_rf = None +app.state.ef = None +app.state.rf = None app.state.YOUTUBE_LOADER_TRANSLATION = None @@ -537,29 +528,34 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, + app.state.ef, ( - app.state.config.OPENAI_API_BASE_URL + app.state.config.RAG_OPENAI_API_BASE_URL if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_BASE_URL + else app.state.config.RAG_OLLAMA_BASE_URL ), ( - app.state.config.OPENAI_API_KEY + app.state.config.RAG_OPENAI_API_KEY if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_API_KEY + else app.state.config.RAG_OLLAMA_API_KEY ), app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) -update_embedding_model( - app.state.config.RAG_EMBEDDING_MODEL, - RAG_EMBEDDING_MODEL_AUTO_UPDATE, -) +try: + app.state.ef = get_ef( + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, + ) -update_reranking_model( - app.state.config.RAG_RERANKING_MODEL, - RAG_RERANKING_MODEL_AUTO_UPDATE, -) + app.state.rf = get_rf( + app.state.config.RAG_RERANKING_MODEL, + RAG_RERANKING_MODEL_AUTO_UPDATE, + ) +except Exception as e: + log.error(f"Error updating models: {e}") + pass ######################################## @@ -990,11 +986,11 @@ async def chat_completion_files_handler( sources = get_sources_from_files( files=files, queries=queries, - embedding_function=retrieval_app.state.EMBEDDING_FUNCTION, - k=retrieval_app.state.config.TOP_K, - reranking_function=retrieval_app.state.sentence_transformer_rf, - r=retrieval_app.state.config.RELEVANCE_THRESHOLD, - hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH, + embedding_function=app.state.EMBEDDING_FUNCTION, + k=app.state.config.TOP_K, + reranking_function=app.state.rf, + r=app.state.config.RELEVANCE_THRESHOLD, + hybrid_search=app.state.config.ENABLE_RAG_HYBRID_SEARCH, ) log.debug(f"rag_contexts:sources: {sources}") diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 5cd7209a8..c40208ac1 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -97,62 +97,58 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ########################################## -def update_embedding_model( - request: Request, +def get_ef( + engine: str, embedding_model: str, auto_update: bool = False, ): - if embedding_model and request.app.state.config.RAG_EMBEDDING_ENGINE == "": + ef = None + if embedding_model and engine == "": from sentence_transformers import SentenceTransformer try: - request.app.state.sentence_transformer_ef = SentenceTransformer( + ef = SentenceTransformer( get_model_path(embedding_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, ) except Exception as e: log.debug(f"Error loading SentenceTransformer: {e}") - request.app.state.sentence_transformer_ef = None - else: - request.app.state.sentence_transformer_ef = None + + return ef -def update_reranking_model( - request: Request, +def get_rf( reranking_model: str, auto_update: bool = False, ): + rf = None if reranking_model: if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]): try: from open_webui.retrieval.models.colbert import ColBERT - request.app.state.sentence_transformer_rf = ColBERT( + rf = ColBERT( get_model_path(reranking_model, auto_update), env="docker" if DOCKER else None, ) + except Exception as e: log.error(f"ColBERT: {e}") - request.app.state.sentence_transformer_rf = None - request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False + raise Exception(ERROR_MESSAGES.DEFAULT(e)) else: import sentence_transformers try: - request.app.state.sentence_transformer_rf = ( - sentence_transformers.CrossEncoder( - get_model_path(reranking_model, auto_update), - device=DEVICE_TYPE, - trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, - ) + rf = sentence_transformers.CrossEncoder( + get_model_path(reranking_model, auto_update), + device=DEVICE_TYPE, + trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, ) except: log.error("CrossEncoder error") - request.app.state.sentence_transformer_rf = None - request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False - else: - request.app.state.sentence_transformer_rf = None + raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error")) + return rf ########################################## @@ -261,12 +257,15 @@ async def update_embedding_config( form_data.embedding_batch_size ) - update_embedding_model(request.app.state.config.RAG_EMBEDDING_MODEL) + request.app.state.ef = get_ef( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + ) request.app.state.EMBEDDING_FUNCTION = get_embedding_function( request.app.state.config.RAG_EMBEDDING_ENGINE, request.app.state.config.RAG_EMBEDDING_MODEL, - request.app.state.sentence_transformer_ef, + request.app.state.ef, ( request.app.state.config.OPENAI_API_BASE_URL if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" @@ -316,7 +315,14 @@ async def update_reranking_config( try: request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model - update_reranking_model(request.app.state.config.RAG_RERANKING_MODEL, True) + try: + request.app.state.rf = get_rf( + request.app.state.config.RAG_RERANKING_MODEL, + True, + ) + except Exception as e: + log.error(f"Error loading reranking model: {e}") + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False return { "status": True, @@ -739,7 +745,7 @@ def save_docs_to_vector_db( embedding_function = get_embedding_function( request.app.state.config.RAG_EMBEDDING_ENGINE, request.app.state.config.RAG_EMBEDDING_MODEL, - request.app.state.sentence_transformer_ef, + request.app.state.ef, ( request.app.state.config.OPENAI_API_BASE_URL if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" @@ -1286,7 +1292,7 @@ def query_doc_handler( query=form_data.query, embedding_function=request.app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.sentence_transformer_rf, + reranking_function=request.app.state.rf, r=( form_data.r if form_data.r @@ -1328,7 +1334,7 @@ def query_collection_handler( queries=[form_data.query], embedding_function=request.app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.sentence_transformer_rf, + reranking_function=request.app.state.rf, r=( form_data.r if form_data.r From 772f5ccd60f1967bbb62afe16bbdacc481b3fec1 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 18:53:38 -0800 Subject: [PATCH 13/26] wip --- backend/open_webui/main.py | 52 ++--- backend/open_webui/routers/images.py | 280 +++++++++++++++------------ backend/open_webui/routers/webui.py | 94 --------- 3 files changed, 180 insertions(+), 246 deletions(-) delete mode 100644 backend/open_webui/routers/webui.py diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 09913ae01..84af1685a 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -697,11 +697,11 @@ async def chat_completion_filter_functions_handler(body, model, extra_params): if not filter: continue - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] + if filter_id in app.state.FUNCTIONS: + function_module = app.state.FUNCTIONS[filter_id] else: function_module, _, _ = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module + app.state.FUNCTIONS[filter_id] = function_module # Check if the function has a file_handler variable if hasattr(function_module, "file_handler"): @@ -828,7 +828,7 @@ async def chat_completion_tools_handler( models, ) tools = get_tools( - webui_app, + app, tool_ids, user, { @@ -1406,7 +1406,7 @@ async def commit_session_after_request(request: Request, call_next): @app.middleware("http") async def check_url(request: Request, call_next): start_time = int(time.time()) - request.state.enable_api_key = webui_app.state.config.ENABLE_API_KEY + request.state.enable_api_key = app.state.config.ENABLE_API_KEY response = await call_next(request) process_time = int(time.time()) - start_time response.headers["X-Process-Time"] = str(process_time) @@ -1913,11 +1913,11 @@ async def get_all_models(): ] def get_function_module_by_id(function_id): - if function_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[function_id] + if function_id in app.state.FUNCTIONS: + function_module = app.state.FUNCTIONS[function_id] else: function_module, _, _ = load_function_module_by_id(function_id) - webui_app.state.FUNCTIONS[function_id] = function_module + app.state.FUNCTIONS[function_id] = function_module for model in models: action_ids = [ @@ -1953,7 +1953,7 @@ async def get_models(user=Depends(get_verified_user)): if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" ] - model_order_list = webui_app.state.config.MODEL_ORDER_LIST + model_order_list = app.state.config.MODEL_ORDER_LIST if model_order_list: model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)} # Sort models by order list priority, with fallback for those not in the list @@ -2229,11 +2229,11 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): if not filter: continue - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] + if filter_id in app.state.FUNCTIONS: + function_module = app.state.FUNCTIONS[filter_id] else: function_module, _, _ = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module + app.state.FUNCTIONS[filter_id] = function_module if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): valves = Functions.get_function_valves_by_id(filter_id) @@ -2340,11 +2340,11 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified } ) - if action_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[action_id] + if action_id in app.state.FUNCTIONS: + function_module = app.state.FUNCTIONS[action_id] else: function_module, _, _ = load_function_module_by_id(action_id) - webui_app.state.FUNCTIONS[action_id] = function_module + app.state.FUNCTIONS[action_id] = function_module if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): valves = Functions.get_function_valves_by_id(action_id) @@ -2448,17 +2448,17 @@ async def get_app_config(request: Request): }, "features": { "auth": WEBUI_AUTH, - "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), - "enable_ldap": webui_app.state.config.ENABLE_LDAP, - "enable_api_key": webui_app.state.config.ENABLE_API_KEY, - "enable_signup": webui_app.state.config.ENABLE_SIGNUP, - "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM, + "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER), + "enable_ldap": app.state.config.ENABLE_LDAP, + "enable_api_key": app.state.config.ENABLE_API_KEY, + "enable_signup": app.state.config.ENABLE_SIGNUP, + "enable_login_form": app.state.config.ENABLE_LOGIN_FORM, **( { "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH, "enable_image_generation": images_app.state.config.ENABLED, - "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, - "enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING, + "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING, + "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING, "enable_admin_export": ENABLE_ADMIN_EXPORT, "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, } @@ -2468,8 +2468,8 @@ async def get_app_config(request: Request): }, **( { - "default_models": webui_app.state.config.DEFAULT_MODELS, - "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, + "default_models": app.state.config.DEFAULT_MODELS, + "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "audio": { "tts": { "engine": audio_app.state.config.TTS_ENGINE, @@ -2484,7 +2484,7 @@ async def get_app_config(request: Request): "max_size": retrieval_app.state.config.FILE_MAX_SIZE, "max_count": retrieval_app.state.config.FILE_MAX_COUNT, }, - "permissions": {**webui_app.state.config.USER_PERMISSIONS}, + "permissions": {**app.state.config.USER_PERMISSIONS}, } if user is not None else {} @@ -2506,7 +2506,7 @@ async def get_webhook_url(user=Depends(get_admin_user)): @app.post("/api/webhook") async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): app.state.config.WEBHOOK_URL = form_data.url - webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL + app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL return {"url": app.state.config.WEBHOOK_URL} diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index f4c12ab64..0deded03e 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -9,6 +9,18 @@ from pathlib import Path from typing import Optional import requests + + +from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel + + +from open_webui.config import CACHE_DIR +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS + +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.images.comfyui import ( ComfyUIGenerateImageForm, ComfyUIWorkflow, @@ -16,48 +28,36 @@ from open_webui.utils.images.comfyui import ( ) -from open_webui.config import CACHE_DIR -from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS - -from fastapi import Depends, FastAPI, HTTPException, Request -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel -from open_webui.utils.auth import get_admin_user, get_verified_user - log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["IMAGES"]) IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) + +router = APIRouter() -@app.get("/config") +@router.get("/config") async def get_config(request: Request, user=Depends(get_admin_user)): return { - "enabled": app.state.config.ENABLED, - "engine": app.state.config.ENGINE, + "enabled": request.app.state.config.ENABLED, + "engine": request.app.state.config.ENGINE, "openai": { - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY, }, "automatic1111": { - "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, - "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE, - "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER, - "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER, + "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, + "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH, + "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE, + "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER, + "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER, }, "comfyui": { - "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, - "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, - "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, + "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL, + "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, + "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, } @@ -89,133 +89,150 @@ class ConfigForm(BaseModel): comfyui: ComfyUIConfigForm -@app.post("/config/update") -async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)): - app.state.config.ENGINE = form_data.engine - app.state.config.ENABLED = form_data.enabled +@router.post("/config/update") +async def update_config( + request: Request, form_data: ConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.ENGINE = form_data.engine + request.app.state.config.ENABLED = form_data.enabled - app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL - app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY + request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL + request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY - app.state.config.AUTOMATIC1111_BASE_URL = ( + request.app.state.config.AUTOMATIC1111_BASE_URL = ( form_data.automatic1111.AUTOMATIC1111_BASE_URL ) - app.state.config.AUTOMATIC1111_API_AUTH = ( + request.app.state.config.AUTOMATIC1111_API_AUTH = ( form_data.automatic1111.AUTOMATIC1111_API_AUTH ) - app.state.config.AUTOMATIC1111_CFG_SCALE = ( + request.app.state.config.AUTOMATIC1111_CFG_SCALE = ( float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE) if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE else None ) - app.state.config.AUTOMATIC1111_SAMPLER = ( + request.app.state.config.AUTOMATIC1111_SAMPLER = ( form_data.automatic1111.AUTOMATIC1111_SAMPLER if form_data.automatic1111.AUTOMATIC1111_SAMPLER else None ) - app.state.config.AUTOMATIC1111_SCHEDULER = ( + request.app.state.config.AUTOMATIC1111_SCHEDULER = ( form_data.automatic1111.AUTOMATIC1111_SCHEDULER if form_data.automatic1111.AUTOMATIC1111_SCHEDULER else None ) - app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/") - app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW - app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES + request.app.state.config.COMFYUI_BASE_URL = ( + form_data.comfyui.COMFYUI_BASE_URL.strip("/") + ) + request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW + request.app.state.config.COMFYUI_WORKFLOW_NODES = ( + form_data.comfyui.COMFYUI_WORKFLOW_NODES + ) return { - "enabled": app.state.config.ENABLED, - "engine": app.state.config.ENGINE, + "enabled": request.app.state.config.ENABLED, + "engine": request.app.state.config.ENGINE, "openai": { - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY, }, "automatic1111": { - "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, - "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE, - "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER, - "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER, + "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, + "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH, + "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE, + "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER, + "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER, }, "comfyui": { - "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, - "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, - "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, + "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL, + "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, + "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, } -def get_automatic1111_api_auth(): - if app.state.config.AUTOMATIC1111_API_AUTH is None: +def get_automatic1111_api_auth(request: Request): + if request.app.state.config.AUTOMATIC1111_API_AUTH is None: return "" else: - auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8") + auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode( + "utf-8" + ) auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string) auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8") return f"Basic {auth1111_base64_encoded_string}" -@app.get("/config/url/verify") -async def verify_url(user=Depends(get_admin_user)): - if app.state.config.ENGINE == "automatic1111": +@router.get("/config/url/verify") +async def verify_url(request: Request, user=Depends(get_admin_user)): + if request.app.state.config.ENGINE == "automatic1111": try: r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - headers={"authorization": get_automatic1111_api_auth()}, + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + headers={"authorization": get_automatic1111_api_auth(request)}, ) r.raise_for_status() return True except Exception: - app.state.config.ENABLED = False + request.app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) - elif app.state.config.ENGINE == "comfyui": + elif request.app.state.config.ENGINE == "comfyui": try: - r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") + r = requests.get( + url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" + ) r.raise_for_status() return True except Exception: - app.state.config.ENABLED = False + request.app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) else: return True -def set_image_model(model: str): +def set_image_model(request: Request, model: str): log.info(f"Setting image model to {model}") - app.state.config.MODEL = model - if app.state.config.ENGINE in ["", "automatic1111"]: + request.app.state.config.MODEL = model + if request.app.state.config.ENGINE in ["", "automatic1111"]: api_auth = get_automatic1111_api_auth() r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", headers={"authorization": api_auth}, ) options = r.json() if model != options["sd_model_checkpoint"]: options["sd_model_checkpoint"] = model r = requests.post( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options, headers={"authorization": api_auth}, ) - return app.state.config.MODEL + return request.app.state.config.MODEL def get_image_model(): - if app.state.config.ENGINE == "openai": - return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2" - elif app.state.config.ENGINE == "comfyui": - return app.state.config.MODEL if app.state.config.MODEL else "" - elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "": + if request.app.state.config.ENGINE == "openai": + return ( + request.app.state.config.MODEL + if request.app.state.config.MODEL + else "dall-e-2" + ) + elif request.app.state.config.ENGINE == "comfyui": + return request.app.state.config.MODEL if request.app.state.config.MODEL else "" + elif ( + request.app.state.config.ENGINE == "automatic1111" + or request.app.state.config.ENGINE == "" + ): try: r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", headers={"authorization": get_automatic1111_api_auth()}, ) options = r.json() return options["sd_model_checkpoint"] except Exception as e: - app.state.config.ENABLED = False + request.app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -225,23 +242,25 @@ class ImageConfigForm(BaseModel): IMAGE_STEPS: int -@app.get("/image/config") -async def get_image_config(user=Depends(get_admin_user)): +@router.get("/image/config") +async def get_image_config(request: Request, user=Depends(get_admin_user)): return { - "MODEL": app.state.config.MODEL, - "IMAGE_SIZE": app.state.config.IMAGE_SIZE, - "IMAGE_STEPS": app.state.config.IMAGE_STEPS, + "MODEL": request.app.state.config.MODEL, + "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE, + "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS, } -@app.post("/image/config/update") -async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)): +@router.post("/image/config/update") +async def update_image_config( + request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user) +): - set_image_model(form_data.MODEL) + set_image_model(request, form_data.MODEL) pattern = r"^\d+x\d+$" if re.match(pattern, form_data.IMAGE_SIZE): - app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE + request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE else: raise HTTPException( status_code=400, @@ -249,7 +268,7 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin ) if form_data.IMAGE_STEPS >= 0: - app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS + request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS else: raise HTTPException( status_code=400, @@ -257,29 +276,31 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin ) return { - "MODEL": app.state.config.MODEL, - "IMAGE_SIZE": app.state.config.IMAGE_SIZE, - "IMAGE_STEPS": app.state.config.IMAGE_STEPS, + "MODEL": request.app.state.config.MODEL, + "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE, + "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS, } -@app.get("/models") -def get_models(user=Depends(get_verified_user)): +@router.get("/models") +def get_models(request: Request, user=Depends(get_verified_user)): try: - if app.state.config.ENGINE == "openai": + if request.app.state.config.ENGINE == "openai": return [ {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] - elif app.state.config.ENGINE == "comfyui": + elif request.app.state.config.ENGINE == "comfyui": # TODO - get models from comfyui - r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") + r = requests.get( + url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" + ) info = r.json() - workflow = json.loads(app.state.config.COMFYUI_WORKFLOW) + workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW) model_node_id = None - for node in app.state.config.COMFYUI_WORKFLOW_NODES: + for node in request.app.state.config.COMFYUI_WORKFLOW_NODES: if node["type"] == "model": if node["node_ids"]: model_node_id = node["node_ids"][0] @@ -315,10 +336,11 @@ def get_models(user=Depends(get_verified_user)): ) ) elif ( - app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + request.app.state.config.ENGINE == "automatic1111" + or request.app.state.config.ENGINE == "" ): r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", headers={"authorization": get_automatic1111_api_auth()}, ) models = r.json() @@ -329,7 +351,7 @@ def get_models(user=Depends(get_verified_user)): ) ) except Exception as e: - app.state.config.ENABLED = False + request.app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -401,18 +423,21 @@ def save_url_image(url): return None -@app.post("/generations") +@router.post("/generations") async def image_generations( + request: Request, form_data: GenerateImageForm, user=Depends(get_verified_user), ): - width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x"))) + width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x"))) r = None try: - if app.state.config.ENGINE == "openai": + if request.app.state.config.ENGINE == "openai": headers = {} - headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" + headers["Authorization"] = ( + f"Bearer {request.app.state.config.OPENAI_API_KEY}" + ) headers["Content-Type"] = "application/json" if ENABLE_FORWARD_USER_INFO_HEADERS: @@ -423,14 +448,16 @@ async def image_generations( data = { "model": ( - app.state.config.MODEL - if app.state.config.MODEL != "" + request.app.state.config.MODEL + if request.app.state.config.MODEL != "" else "dall-e-2" ), "prompt": form_data.prompt, "n": form_data.n, "size": ( - form_data.size if form_data.size else app.state.config.IMAGE_SIZE + form_data.size + if form_data.size + else request.app.state.config.IMAGE_SIZE ), "response_format": "b64_json", } @@ -438,7 +465,7 @@ async def image_generations( # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations", + url=f"{request.app.state.config.OPENAI_API_BASE_URL}/images/generations", json=data, headers=headers, ) @@ -458,7 +485,7 @@ async def image_generations( return images - elif app.state.config.ENGINE == "comfyui": + elif request.app.state.config.ENGINE == "comfyui": data = { "prompt": form_data.prompt, "width": width, @@ -466,8 +493,8 @@ async def image_generations( "n": form_data.n, } - if app.state.config.IMAGE_STEPS is not None: - data["steps"] = app.state.config.IMAGE_STEPS + if request.app.state.config.IMAGE_STEPS is not None: + data["steps"] = request.app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt @@ -476,18 +503,18 @@ async def image_generations( **{ "workflow": ComfyUIWorkflow( **{ - "workflow": app.state.config.COMFYUI_WORKFLOW, - "nodes": app.state.config.COMFYUI_WORKFLOW_NODES, + "workflow": request.app.state.config.COMFYUI_WORKFLOW, + "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES, } ), **data, } ) res = await comfyui_generate_image( - app.state.config.MODEL, + request.app.state.config.MODEL, form_data, user.id, - app.state.config.COMFYUI_BASE_URL, + request.app.state.config.COMFYUI_BASE_URL, ) log.debug(f"res: {res}") @@ -504,7 +531,8 @@ async def image_generations( log.debug(f"images: {images}") return images elif ( - app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + request.app.state.config.ENGINE == "automatic1111" + or request.app.state.config.ENGINE == "" ): if form_data.model: set_image_model(form_data.model) @@ -516,25 +544,25 @@ async def image_generations( "height": height, } - if app.state.config.IMAGE_STEPS is not None: - data["steps"] = app.state.config.IMAGE_STEPS + if request.app.state.config.IMAGE_STEPS is not None: + data["steps"] = request.app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt - if app.state.config.AUTOMATIC1111_CFG_SCALE: - data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE + if request.app.state.config.AUTOMATIC1111_CFG_SCALE: + data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE - if app.state.config.AUTOMATIC1111_SAMPLER: - data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER + if request.app.state.config.AUTOMATIC1111_SAMPLER: + data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER - if app.state.config.AUTOMATIC1111_SCHEDULER: - data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER + if request.app.state.config.AUTOMATIC1111_SCHEDULER: + data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", json=data, headers={"authorization": get_automatic1111_api_auth()}, ) diff --git a/backend/open_webui/routers/webui.py b/backend/open_webui/routers/webui.py deleted file mode 100644 index 058d4c365..000000000 --- a/backend/open_webui/routers/webui.py +++ /dev/null @@ -1,94 +0,0 @@ -import inspect -import json -import logging -import time -from typing import AsyncGenerator, Generator, Iterator - -from open_webui.socket.main import get_event_call, get_event_emitter -from open_webui.models.functions import Functions -from open_webui.models.models import Models -from open_webui.routers import ( - auths, - chats, - folders, - configs, - groups, - files, - functions, - memories, - models, - knowledge, - prompts, - evaluations, - tools, - users, - utils, -) -from open_webui.utils.plugin import load_function_module_by_id -from open_webui.config import ( - ADMIN_EMAIL, - CORS_ALLOW_ORIGIN, - DEFAULT_MODELS, - DEFAULT_PROMPT_SUGGESTIONS, - DEFAULT_USER_ROLE, - MODEL_ORDER_LIST, - ENABLE_COMMUNITY_SHARING, - ENABLE_LOGIN_FORM, - ENABLE_MESSAGE_RATING, - ENABLE_SIGNUP, - ENABLE_API_KEY, - ENABLE_EVALUATION_ARENA_MODELS, - EVALUATION_ARENA_MODELS, - DEFAULT_ARENA_MODEL, - JWT_EXPIRES_IN, - ENABLE_OAUTH_ROLE_MANAGEMENT, - OAUTH_ROLES_CLAIM, - OAUTH_EMAIL_CLAIM, - OAUTH_PICTURE_CLAIM, - OAUTH_USERNAME_CLAIM, - OAUTH_ALLOWED_ROLES, - OAUTH_ADMIN_ROLES, - SHOW_ADMIN_DETAILS, - USER_PERMISSIONS, - WEBHOOK_URL, - WEBUI_AUTH, - WEBUI_BANNERS, - ENABLE_LDAP, - LDAP_SERVER_LABEL, - LDAP_SERVER_HOST, - LDAP_SERVER_PORT, - LDAP_ATTRIBUTE_FOR_USERNAME, - LDAP_SEARCH_FILTERS, - LDAP_SEARCH_BASE, - LDAP_APP_DN, - LDAP_APP_PASSWORD, - LDAP_USE_TLS, - LDAP_CA_CERT_FILE, - LDAP_CIPHERS, - AppConfig, -) -from open_webui.env import ( - ENV, - SRC_LOG_LEVELS, - WEBUI_AUTH_TRUSTED_EMAIL_HEADER, - WEBUI_AUTH_TRUSTED_NAME_HEADER, -) -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse -from pydantic import BaseModel -from open_webui.utils.misc import ( - openai_chat_chunk_message_template, - openai_chat_completion_message_template, -) -from open_webui.utils.payload import ( - apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, -) - - -from open_webui.utils.tools import get_tools - - -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["MAIN"]) From fe5519e0a2ab884a685e9d02d9f2695bae8c8a9e Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 19:52:46 -0800 Subject: [PATCH 14/26] wip --- backend/open_webui/main.py | 242 ++++++------------------ backend/open_webui/routers/ollama.py | 11 +- backend/open_webui/routers/pipelines.py | 140 +++++++++++++- backend/open_webui/routers/tasks.py | 39 ++-- backend/open_webui/utils/task.py | 16 ++ 5 files changed, 236 insertions(+), 212 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 84af1685a..dbb9518af 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -75,6 +75,11 @@ from open_webui.routers.retrieval import ( get_ef, get_rf, ) +from open_webui.routers.pipelines import ( + process_pipeline_inlet_filter, + process_pipeline_outlet_filter, +) + from open_webui.retrieval.utils import get_sources_from_files @@ -290,6 +295,7 @@ from open_webui.utils.response import ( ) from open_webui.utils.task import ( + get_task_model_id, rag_template, tools_function_calling_generation_template, ) @@ -662,35 +668,36 @@ app.state.MODELS = {} ################################## -def get_filter_function_ids(model): - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - filter_ids.sort(key=get_priority) - return filter_ids - - async def chat_completion_filter_functions_handler(body, model, extra_params): skip_files = None + def get_filter_function_ids(model): + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [ + function.id for function in Functions.get_global_filter_functions() + ] + if "info" in model and "meta" in model["info"]: + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) + + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + + filter_ids = [ + filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids + ] + + filter_ids.sort(key=get_priority) + return filter_ids + filter_ids = get_filter_function_ids(model) for filter_id in filter_ids: filter = Functions.get_function_by_id(filter_id) @@ -791,22 +798,6 @@ async def get_content_from_response(response) -> Optional[str]: return content -def get_task_model_id( - default_model_id: str, task_model: str, task_model_external: str, models -) -> str: - # Set the task model - task_model_id = default_model_id - # Check if the user has a custom task model and use that model - if models[task_model_id]["owned_by"] == "ollama": - if task_model and task_model in models: - task_model_id = task_model - else: - if task_model_external and task_model_external in models: - task_model_id = task_model_external - - return task_model_id - - async def chat_completion_tools_handler( body: dict, user: UserModel, models, extra_params: dict ) -> tuple[dict, dict]: @@ -857,7 +848,7 @@ async def chat_completion_tools_handler( ) try: - payload = filter_pipeline(payload, user, models) + payload = process_pipeline_inlet_filter(request, payload, user, models) except Exception as e: raise e @@ -1153,7 +1144,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if prompt is None: raise Exception("No user message found") if ( - retrieval_app.state.config.RELEVANCE_THRESHOLD == 0 + app.state.config.RELEVANCE_THRESHOLD == 0 and context_string.strip() == "" ): log.debug( @@ -1164,16 +1155,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): # TODO: replace with add_or_update_system_message if model["owned_by"] == "ollama": body["messages"] = prepend_to_first_user_message_content( - rag_template( - retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt - ), + rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt), body["messages"], ) else: body["messages"] = add_or_update_system_message( - rag_template( - retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt - ), + rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt), body["messages"], ) @@ -1225,77 +1212,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): app.add_middleware(ChatCompletionMiddleware) -################################## -# -# Pipeline Middleware -# -################################## - - -def get_sorted_filters(model_id, models): - filters = [ - model - for model in models.values() - if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) - ] - sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) - return sorted_filters - - -def filter_pipeline(payload, user, models): - user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} - model_id = payload["model"] - - sorted_filters = get_sorted_filters(model_id, models) - model = models[model_id] - - if "pipeline" in model: - sorted_filters.append(model) - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = app.state.config.OPENAI_API_KEYS[urlIdx] - - if key == "": - continue - - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/inlet", - headers=headers, - json={ - "user": user, - "body": payload, - }, - ) - - r.raise_for_status() - payload = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - res = r.json() - if "detail" in res: - raise Exception(r.status_code, res["detail"]) - - return payload - - class PipelineMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): if not request.method == "POST" and any( @@ -1335,11 +1251,11 @@ class PipelineMiddleware(BaseHTTPMiddleware): content={"detail": e.detail}, ) - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} + await get_all_models() + models = app.state.MODELS try: - data = filter_pipeline(data, user, models) + data = process_pipeline_inlet_filter(request, data, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -1447,8 +1363,8 @@ app.include_router(ollama.router, prefix="/ollama", tags=["ollama"]) app.include_router(openai.router, prefix="/openai", tags=["openai"]) -app.include_router(pipelines.router, prefix="/pipelines", tags=["pipelines"]) -app.include_router(tasks.router, prefix="/tasks", tags=["tasks"]) +app.include_router(pipelines.router, prefix="/api/pipelines", tags=["pipelines"]) +app.include_router(tasks.router, prefix="/api/tasks", tags=["tasks"]) app.include_router(images.router, prefix="/api/v1/images", tags=["images"]) @@ -2105,7 +2021,6 @@ async def generate_chat_completions( if model["owned_by"] == "ollama": # Using /ollama/api/chat endpoint form_data = convert_payload_openai_to_ollama(form_data) - form_data = GenerateChatCompletionForm(**form_data) response = await generate_ollama_chat_completion( form_data=form_data, user=user, bypass_filter=bypass_filter ) @@ -2124,7 +2039,9 @@ async def generate_chat_completions( @app.post("/api/chat/completed") -async def chat_completed(form_data: dict, user=Depends(get_verified_user)): +async def chat_completed( + request: Request, form_data: dict, user=Depends(get_verified_user) +): model_list = await get_all_models() models = {model["id"]: model for model in model_list} @@ -2137,53 +2054,14 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ) model = models[model_id] - sorted_filters = get_sorted_filters(model_id, models) - if "pipeline" in model: - sorted_filters = [model] + sorted_filters - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = app.state.config.OPENAI_API_KEYS[urlIdx] - - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/outlet", - headers=headers, - json={ - "user": { - "id": user.id, - "name": user.name, - "email": user.email, - "role": user.role, - }, - "body": data, - }, - ) - - r.raise_for_status() - data = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - try: - res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) - except Exception: - pass - - else: - pass + try: + data = process_pipeline_outlet_filter(request, data, user, models) + except Exception as e: + return HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) __event_emitter__ = get_event_emitter( { @@ -2455,8 +2333,8 @@ async def get_app_config(request: Request): "enable_login_form": app.state.config.ENABLE_LOGIN_FORM, **( { - "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH, - "enable_image_generation": images_app.state.config.ENABLED, + "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH, + "enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION, "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING, "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING, "enable_admin_export": ENABLE_ADMIN_EXPORT, @@ -2472,17 +2350,17 @@ async def get_app_config(request: Request): "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "audio": { "tts": { - "engine": audio_app.state.config.TTS_ENGINE, - "voice": audio_app.state.config.TTS_VOICE, - "split_on": audio_app.state.config.TTS_SPLIT_ON, + "engine": app.state.config.TTS_ENGINE, + "voice": app.state.config.TTS_VOICE, + "split_on": app.state.config.TTS_SPLIT_ON, }, "stt": { - "engine": audio_app.state.config.STT_ENGINE, + "engine": app.state.config.STT_ENGINE, }, }, "file": { - "max_size": retrieval_app.state.config.FILE_MAX_SIZE, - "max_count": retrieval_app.state.config.FILE_MAX_COUNT, + "max_size": app.state.config.FILE_MAX_SIZE, + "max_count": app.state.config.FILE_MAX_COUNT, }, "permissions": {**app.state.config.USER_PERMISSIONS}, } diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index b217b8f45..c36c2d730 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -941,7 +941,7 @@ async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = @router.post("/api/chat/{url_idx}") async def generate_chat_completion( request: Request, - form_data: GenerateChatCompletionForm, + form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), bypass_filter: Optional[bool] = False, @@ -949,6 +949,15 @@ async def generate_chat_completion( if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True + try: + form_data = GenerateChatCompletionForm(**form_data) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=400, + detail=str(e), + ) + payload = {**form_data.model_dump(exclude_none=True)} if "metadata" in payload: del payload["metadata"] diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py index f1cdae140..258c10ee6 100644 --- a/backend/open_webui/routers/pipelines.py +++ b/backend/open_webui/routers/pipelines.py @@ -30,6 +30,130 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) +################################## +# +# Pipeline Middleware +# +################################## + + +def get_sorted_filters(model_id, models): + filters = [ + model + for model in models.values() + if "pipeline" in model + and "type" in model["pipeline"] + and model["pipeline"]["type"] == "filter" + and ( + model["pipeline"]["pipelines"] == ["*"] + or any( + model_id == target_model_id + for target_model_id in model["pipeline"]["pipelines"] + ) + ) + ] + sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + return sorted_filters + + +def process_pipeline_inlet_filter(request, payload, user, models): + user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} + model_id = payload["model"] + + sorted_filters = get_sorted_filters(model_id, models) + model = models[model_id] + + if "pipeline" in model: + sorted_filters.append(model) + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + if key == "": + continue + + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/inlet", + headers=headers, + json={ + "user": user, + "body": payload, + }, + ) + + r.raise_for_status() + payload = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + res = r.json() + if "detail" in res: + raise Exception(r.status_code, res["detail"]) + + return payload + + +def process_pipeline_outlet_filter(request, payload, user, models): + user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} + model_id = payload["model"] + + sorted_filters = get_sorted_filters(model_id, models) + model = models[model_id] + + if "pipeline" in model: + sorted_filters = [model] + sorted_filters + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + if key != "": + r = requests.post( + f"{url}/{filter['id']}/filter/outlet", + headers={"Authorization": f"Bearer {key}"}, + json={ + "user": { + "id": user.id, + "name": user.name, + "email": user.email, + "role": user.role, + }, + "body": data, + }, + ) + + r.raise_for_status() + data = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + try: + res = r.json() + if "detail" in res: + return Exception(r.status_code, res) + except Exception: + pass + + else: + pass + + return payload + + ################################## # # Pipelines Endpoints @@ -39,7 +163,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) router = APIRouter() -@router.get("/api/pipelines/list") +@router.get("/list") async def get_pipelines_list(request: Request, user=Depends(get_admin_user)): responses = await get_all_models_responses(request) log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}") @@ -61,7 +185,7 @@ async def get_pipelines_list(request: Request, user=Depends(get_admin_user)): } -@router.post("/api/pipelines/upload") +@router.post("/upload") async def upload_pipeline( request: Request, urlIdx: int = Form(...), @@ -131,7 +255,7 @@ class AddPipelineForm(BaseModel): urlIdx: int -@router.post("/api/pipelines/add") +@router.post("/add") async def add_pipeline( request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user) ): @@ -176,7 +300,7 @@ class DeletePipelineForm(BaseModel): urlIdx: int -@router.delete("/api/pipelines/delete") +@router.delete("/delete") async def delete_pipeline( request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user) ): @@ -216,7 +340,7 @@ async def delete_pipeline( ) -@router.get("/api/pipelines") +@router.get("/") async def get_pipelines( request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user) ): @@ -250,7 +374,7 @@ async def get_pipelines( ) -@router.get("/api/pipelines/{pipeline_id}/valves") +@router.get("/{pipeline_id}/valves") async def get_pipeline_valves( request: Request, urlIdx: Optional[int], @@ -289,7 +413,7 @@ async def get_pipeline_valves( ) -@router.get("/api/pipelines/{pipeline_id}/valves/spec") +@router.get("/{pipeline_id}/valves/spec") async def get_pipeline_valves_spec( request: Request, urlIdx: Optional[int], @@ -329,7 +453,7 @@ async def get_pipeline_valves_spec( ) -@router.post("/api/pipelines/{pipeline_id}/valves/update") +@router.post("/{pipeline_id}/valves/update") async def update_pipeline_valves( request: Request, urlIdx: Optional[int], diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index 4af25d4d3..13e5a95a3 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, Depends, HTTPException, Response, status, Request +from fastapi.responses import JSONResponse, RedirectResponse + from pydantic import BaseModel -from starlette.responses import FileResponse from typing import Optional import logging @@ -16,6 +17,9 @@ from open_webui.utils.task import ( from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.constants import TASKS +from open_webui.routers.pipelines import process_pipeline_inlet_filter +from open_webui.utils.task import get_task_model_id + from open_webui.config import ( DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, @@ -121,9 +125,7 @@ async def update_task_config( async def generate_title( request: Request, form_data: dict, user=Depends(get_verified_user) ): - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} + models = request.app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -191,7 +193,7 @@ Artificial Intelligence in Healthcare # Handle pipeline filters try: - payload = filter_pipeline(payload, user, models) + payload = process_pipeline_inlet_filter(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -220,8 +222,7 @@ async def generate_chat_tags( content={"detail": "Tags generation is disabled"}, ) - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} + models = request.app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -281,7 +282,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } # Handle pipeline filters try: - payload = filter_pipeline(payload, user, models) + payload = process_pipeline_inlet_filter(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -318,8 +319,7 @@ async def generate_queries( detail=f"Query generation is disabled", ) - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} + models = request.app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -363,7 +363,7 @@ async def generate_queries( # Handle pipeline filters try: - payload = filter_pipeline(payload, user, models) + payload = process_pipeline_inlet_filter(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -405,8 +405,7 @@ async def generate_autocompletion( detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", ) - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} + models = request.app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -450,7 +449,7 @@ async def generate_autocompletion( # Handle pipeline filters try: - payload = filter_pipeline(payload, user, models) + payload = process_pipeline_inlet_filter(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -473,8 +472,7 @@ async def generate_emoji( request: Request, form_data: dict, user=Depends(get_verified_user) ): - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} + models = request.app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -525,7 +523,7 @@ Message: """{{prompt}}""" # Handle pipeline filters try: - payload = filter_pipeline(payload, user, models) + payload = process_pipeline_inlet_filter(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -548,10 +546,9 @@ async def generate_moa_response( request: Request, form_data: dict, user=Depends(get_verified_user) ): - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - + models = request.app.state.MODELS model_id = form_data["model"] + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -593,7 +590,7 @@ Responses from models: {{responses}}""" } try: - payload = filter_pipeline(payload, user, models) + payload = process_pipeline_inlet_filter(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 604161a31..ebb7483ba 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -16,6 +16,22 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) +def get_task_model_id( + default_model_id: str, task_model: str, task_model_external: str, models +) -> str: + # Set the task model + task_model_id = default_model_id + # Check if the user has a custom task model and use that model + if models[task_model_id]["owned_by"] == "ollama": + if task_model and task_model in models: + task_model_id = task_model + else: + if task_model_external and task_model_external in models: + task_model_id = task_model_external + + return task_model_id + + def prompt_template( template: str, user_name: Optional[str] = None, user_location: Optional[str] = None ) -> str: From a07ff56c5010b127fc8b64f82d1f1e14d293e27a Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 20:15:23 -0800 Subject: [PATCH 15/26] wip --- backend/open_webui/main.py | 34 +++++++------ backend/open_webui/routers/openai.py | 71 ++++++++++++++-------------- 2 files changed, 55 insertions(+), 50 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index dbb9518af..a632a3874 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1009,9 +1009,12 @@ async def get_body_and_model_and_user(request, models): class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - if not request.method == "POST" and any( - endpoint in request.url.path - for endpoint in ["/ollama/api/chat", "/chat/completions"] + if not ( + request.method == "POST" + and any( + endpoint in request.url.path + for endpoint in ["/ollama/api/chat", "/chat/completions"] + ) ): return await call_next(request) log.debug(f"request.url.path: {request.url.path}") @@ -1214,9 +1217,12 @@ app.add_middleware(ChatCompletionMiddleware) class PipelineMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - if not request.method == "POST" and any( - endpoint in request.url.path - for endpoint in ["/ollama/api/chat", "/chat/completions"] + if not ( + request.method == "POST" + and any( + endpoint in request.url.path + for endpoint in ["/ollama/api/chat", "/chat/completions"] + ) ): return await call_next(request) @@ -1664,17 +1670,17 @@ async def generate_function_chat_completion(form_data, user, models: dict = {}): return openai_chat_completion_message_template(form_data["model"], message) -async def get_all_base_models(): +async def get_all_base_models(request): function_models = [] openai_models = [] ollama_models = [] if app.state.config.ENABLE_OPENAI_API: - openai_models = await openai.get_all_models() + openai_models = await openai.get_all_models(request) openai_models = openai_models["data"] if app.state.config.ENABLE_OLLAMA_API: - ollama_models = await ollama.get_all_models() + ollama_models = await ollama.get_all_models(request) ollama_models = [ { "id": model["model"], @@ -1729,8 +1735,8 @@ async def get_all_base_models(): @cached(ttl=3) -async def get_all_models(): - models = await get_all_base_models() +async def get_all_models(request): + models = await get_all_base_models(request) # If there are no models, return an empty list if len([model for model in models if not model.get("arena", False)]) == 0: @@ -1859,8 +1865,8 @@ async def get_all_models(): @app.get("/api/models") -async def get_models(user=Depends(get_verified_user)): - models = await get_all_models() +async def get_models(request: Request, user=Depends(get_verified_user)): + models = await get_all_models(request) # Filter out filter pipelines models = [ @@ -2042,7 +2048,7 @@ async def generate_chat_completions( async def chat_completed( request: Request, form_data: dict, user=Depends(get_verified_user) ): - model_list = await get_all_models() + model_list = await get_all_models(request) models = {model["id"]: model for model in model_list} data = form_data diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 34c5683a8..657f3662a 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -245,41 +245,6 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) -def merge_models_lists(model_lists): - log.debug(f"merge_models_lists {model_lists}") - merged_list = [] - - for idx, models in enumerate(model_lists): - if models is not None and "error" not in models: - merged_list.extend( - [ - { - **model, - "name": model.get("name", model["id"]), - "owned_by": "openai", - "openai": model, - "urlIdx": idx, - } - for model in models - if "api.openai.com" - not in request.app.state.config.OPENAI_API_BASE_URLS[idx] - or not any( - name in model["id"] - for name in [ - "babbage", - "dall-e", - "davinci", - "embedding", - "tts", - "whisper", - ] - ) - ] - ) - - return merged_list - - async def get_all_models_responses(request: Request) -> list: if not request.app.state.config.ENABLE_OPENAI_API: return [] @@ -379,7 +344,7 @@ async def get_all_models(request: Request) -> dict[str, list]: if not request.app.state.config.ENABLE_OPENAI_API: return {"data": []} - responses = await get_all_models_responses() + responses = await get_all_models_responses(request) def extract_data(response): if response and "data" in response: @@ -388,6 +353,40 @@ async def get_all_models(request: Request) -> dict[str, list]: return response return None + def merge_models_lists(model_lists): + log.debug(f"merge_models_lists {model_lists}") + merged_list = [] + + for idx, models in enumerate(model_lists): + if models is not None and "error" not in models: + merged_list.extend( + [ + { + **model, + "name": model.get("name", model["id"]), + "owned_by": "openai", + "openai": model, + "urlIdx": idx, + } + for model in models + if "api.openai.com" + not in request.app.state.config.OPENAI_API_BASE_URLS[idx] + or not any( + name in model["id"] + for name in [ + "babbage", + "dall-e", + "davinci", + "embedding", + "tts", + "whisper", + ] + ) + ] + ) + + return merged_list + models = {"data": merge_models_lists(map(extract_data, responses))} log.debug(f"models: {models}") From eb9733e99ffb5c6f4eafa8d1d083ddfe40af7c02 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 20:25:46 -0800 Subject: [PATCH 16/26] wip --- src/lib/constants.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib/constants.ts b/src/lib/constants.ts index 700bd3c42..d92f33671 100644 --- a/src/lib/constants.ts +++ b/src/lib/constants.ts @@ -9,9 +9,9 @@ export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`; export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama`; export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai`; -export const AUDIO_API_BASE_URL = `${WEBUI_BASE_URL}/audio/api/v1`; -export const IMAGES_API_BASE_URL = `${WEBUI_BASE_URL}/images/api/v1`; -export const RETRIEVAL_API_BASE_URL = `${WEBUI_BASE_URL}/retrieval/api/v1`; +export const AUDIO_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1/audio`; +export const IMAGES_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1/images`; +export const RETRIEVAL_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1/retrieval`; export const WEBUI_VERSION = APP_VERSION; export const WEBUI_BUILD_HASH = APP_BUILD_HASH; From d9ffcea764c3eeefaf2dff1b0e47a6e51c27fae7 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 20:26:24 -0800 Subject: [PATCH 17/26] wip --- backend/open_webui/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index a632a3874..c1fab6b9c 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1914,8 +1914,8 @@ async def get_models(request: Request, user=Depends(get_verified_user)): @app.get("/api/models/base") -async def get_base_models(user=Depends(get_admin_user)): - models = await get_all_base_models() +async def get_base_models(request: Request, user=Depends(get_admin_user)): + models = await get_all_base_models(request) # Filter out arena models models = [model for model in models if not model.get("arena", False)] From 866c3dff116bd96707f8ed3f11f6331298d6857f Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 20:39:55 -0800 Subject: [PATCH 18/26] fix --- backend/open_webui/main.py | 38 +++++++++++++++++++--------- backend/open_webui/routers/ollama.py | 18 ++++++------- backend/open_webui/routers/openai.py | 2 +- 3 files changed, 36 insertions(+), 22 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index c1fab6b9c..a49c225b3 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -70,6 +70,15 @@ from open_webui.routers import ( users, utils, ) + +from open_webui.routers.openai import ( + generate_chat_completion as generate_openai_chat_completion, +) + +from open_webui.routers.ollama import ( + generate_chat_completion as generate_ollama_chat_completion, +) + from open_webui.routers.retrieval import ( get_embedding_function, get_ef, @@ -1019,8 +1028,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): return await call_next(request) log.debug(f"request.url.path: {request.url.path}") - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} + await get_all_models(request) + models = app.state.MODELS try: body, model, user = await get_body_and_model_and_user(request, models) @@ -1257,7 +1266,7 @@ class PipelineMiddleware(BaseHTTPMiddleware): content={"detail": e.detail}, ) - await get_all_models() + await get_all_models(request) models = app.state.MODELS try: @@ -1924,6 +1933,7 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)): @app.post("/api/chat/completions") async def generate_chat_completions( + request: Request, form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False, @@ -1931,8 +1941,7 @@ async def generate_chat_completions( if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True - model_list = app.state.MODELS - models = {model["id"]: model for model in model_list} + models = app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -1981,7 +1990,7 @@ async def generate_chat_completions( if model_ids and filter_mode == "exclude": model_ids = [ model["id"] - for model in await get_all_models() + for model in await get_all_models(request) if model.get("owned_by") != "arena" and model["id"] not in model_ids ] @@ -1991,7 +2000,7 @@ async def generate_chat_completions( else: model_ids = [ model["id"] - for model in await get_all_models() + for model in await get_all_models(request) if model.get("owned_by") != "arena" ] selected_model_id = random.choice(model_ids) @@ -2028,6 +2037,7 @@ async def generate_chat_completions( # Using /ollama/api/chat endpoint form_data = convert_payload_openai_to_ollama(form_data) response = await generate_ollama_chat_completion( + request=request, form_data=form_data, user=user, bypass_filter=bypass_filter ) if form_data.stream: @@ -2040,6 +2050,8 @@ async def generate_chat_completions( return convert_response_ollama_to_openai(response) else: return await generate_openai_chat_completion( + request=request, + form_data, user=user, bypass_filter=bypass_filter ) @@ -2048,8 +2060,8 @@ async def generate_chat_completions( async def chat_completed( request: Request, form_data: dict, user=Depends(get_verified_user) ): - model_list = await get_all_models(request) - models = {model["id"]: model for model in model_list} + await get_all_models(request) + models = app.state.MODELS data = form_data model_id = data["model"] @@ -2183,7 +2195,9 @@ async def chat_completed( @app.post("/api/chat/actions/{action_id}") -async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)): +async def chat_action( + request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user) +): if "." in action_id: action_id, sub_action_id = action_id.split(".") else: @@ -2196,8 +2210,8 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified detail="Action not found", ) - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} + await get_all_models(request) + models = app.state.MODELS data = form_data model_id = data["model"] diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index c36c2d730..233e30ce5 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -344,7 +344,7 @@ async def get_ollama_tags( models = [] if url_idx is None: - models = await get_all_models() + models = await get_all_models(request) else: url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) @@ -565,7 +565,7 @@ async def copy_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models() + await get_all_models(request) models = request.app.state.OLLAMA_MODELS if form_data.source in models: @@ -620,7 +620,7 @@ async def delete_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models() + await get_all_models(request) models = request.app.state.OLLAMA_MODELS if form_data.name in models: @@ -670,7 +670,7 @@ async def delete_model( async def show_model_info( request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) ): - await get_all_models() + await get_all_models(request) models = request.app.state.OLLAMA_MODELS if form_data.name not in models: @@ -734,7 +734,7 @@ async def embed( log.info(f"generate_ollama_batch_embeddings {form_data}") if url_idx is None: - await get_all_models() + await get_all_models(request) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -803,7 +803,7 @@ async def embeddings( log.info(f"generate_ollama_embeddings {form_data}") if url_idx is None: - await get_all_models() + await get_all_models(request) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -878,8 +878,8 @@ async def generate_completion( user=Depends(get_verified_user), ): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -1200,7 +1200,7 @@ async def get_openai_models( models = [] if url_idx is None: - model_list = await get_all_models() + model_list = await get_all_models(request) models = [ { "id": model["model"], diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 657f3662a..f7f78be85 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -404,7 +404,7 @@ async def get_models( } if url_idx is None: - models = await get_all_models() + models = await get_all_models(request) else: url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx] key = request.app.state.config.OPENAI_API_KEYS[url_idx] From 403262d7640a1418bd1d0b601f22cecd7988e477 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 20:40:20 -0800 Subject: [PATCH 19/26] fix --- backend/open_webui/main.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index a49c225b3..eba8e51e8 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -2037,8 +2037,7 @@ async def generate_chat_completions( # Using /ollama/api/chat endpoint form_data = convert_payload_openai_to_ollama(form_data) response = await generate_ollama_chat_completion( - request=request, - form_data=form_data, user=user, bypass_filter=bypass_filter + request=request, form_data=form_data, user=user, bypass_filter=bypass_filter ) if form_data.stream: response.headers["content-type"] = "text/event-stream" @@ -2050,9 +2049,7 @@ async def generate_chat_completions( return convert_response_ollama_to_openai(response) else: return await generate_openai_chat_completion( - request=request, - - form_data, user=user, bypass_filter=bypass_filter + request=request, form_data=form_data, user=user, bypass_filter=bypass_filter ) From 4311bb7b99a32963e0a0462e6273d1f67efddbe3 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 12 Dec 2024 20:22:17 -0800 Subject: [PATCH 20/26] wip --- backend/open_webui/env.py | 2 +- backend/open_webui/functions.py | 315 +++++++ backend/open_webui/main.py | 1019 +++-------------------- backend/open_webui/routers/images.py | 14 +- backend/open_webui/routers/retrieval.py | 46 +- backend/open_webui/routers/tasks.py | 26 +- backend/open_webui/utils/chat.py | 380 +++++++++ backend/open_webui/utils/models.py | 222 +++++ backend/open_webui/utils/tools.py | 12 +- src/lib/apis/index.ts | 32 +- 10 files changed, 1102 insertions(+), 966 deletions(-) create mode 100644 backend/open_webui/functions.py create mode 100644 backend/open_webui/utils/chat.py create mode 100644 backend/open_webui/utils/models.py diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index e1b350ead..0fd6080de 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -376,7 +376,7 @@ else: AIOHTTP_CLIENT_TIMEOUT = 300 AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get( - "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "5" + "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "" ) if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "": diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py new file mode 100644 index 000000000..d424d4663 --- /dev/null +++ b/backend/open_webui/functions.py @@ -0,0 +1,315 @@ +import logging +import sys +import inspect +import json + +from pydantic import BaseModel +from typing import AsyncGenerator, Generator, Iterator +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + Request, + UploadFile, + status, +) +from starlette.responses import Response, StreamingResponse + + +from open_webui.socket.main import ( + get_event_call, + get_event_emitter, +) + + +from open_webui.models.functions import Functions +from open_webui.models.models import Models + +from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.tools import get_tools +from open_webui.utils.access_control import has_access + +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL + +from open_webui.utils.misc import ( + add_or_update_system_message, + get_last_user_message, + prepend_to_first_user_message_content, + openai_chat_chunk_message_template, + openai_chat_completion_message_template, +) +from open_webui.utils.payload import ( + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, +) + + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +def get_function_module_by_id(request: Request, pipe_id: str): + # Check if function is already loaded + if pipe_id not in request.app.state.FUNCTIONS: + function_module, _, _ = load_function_module_by_id(pipe_id) + request.app.state.FUNCTIONS[pipe_id] = function_module + else: + function_module = request.app.state.FUNCTIONS[pipe_id] + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(pipe_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + return function_module + + +async def get_function_models(): + pipes = Functions.get_functions_by_type("pipe", active_only=True) + pipe_models = [] + + for pipe in pipes: + function_module = get_function_module_by_id(pipe.id) + + # Check if function is a manifold + if hasattr(function_module, "pipes"): + sub_pipes = [] + + # Check if pipes is a function or a list + + try: + if callable(function_module.pipes): + sub_pipes = function_module.pipes() + else: + sub_pipes = function_module.pipes + except Exception as e: + log.exception(e) + sub_pipes = [] + + log.debug( + f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}" + ) + + for p in sub_pipes: + sub_pipe_id = f'{pipe.id}.{p["id"]}' + sub_pipe_name = p["name"] + + if hasattr(function_module, "name"): + sub_pipe_name = f"{function_module.name}{sub_pipe_name}" + + pipe_flag = {"type": pipe.type} + + pipe_models.append( + { + "id": sub_pipe_id, + "name": sub_pipe_name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": pipe_flag, + } + ) + else: + pipe_flag = {"type": "pipe"} + + log.debug( + f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}" + ) + + pipe_models.append( + { + "id": pipe.id, + "name": pipe.name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": pipe_flag, + } + ) + + return pipe_models + + +async def generate_function_chat_completion( + request, form_data, user, models: dict = {} +): + async def execute_pipe(pipe, params): + if inspect.iscoroutinefunction(pipe): + return await pipe(**params) + else: + return pipe(**params) + + async def get_message_content(res: str | Generator | AsyncGenerator) -> str: + if isinstance(res, str): + return res + if isinstance(res, Generator): + return "".join(map(str, res)) + if isinstance(res, AsyncGenerator): + return "".join([str(stream) async for stream in res]) + + def process_line(form_data: dict, line): + if isinstance(line, BaseModel): + line = line.model_dump_json() + line = f"data: {line}" + if isinstance(line, dict): + line = f"data: {json.dumps(line)}" + + try: + line = line.decode("utf-8") + except Exception: + pass + + if line.startswith("data:"): + return f"{line}\n\n" + else: + line = openai_chat_chunk_message_template(form_data["model"], line) + return f"data: {json.dumps(line)}\n\n" + + def get_pipe_id(form_data: dict) -> str: + pipe_id = form_data["model"] + if "." in pipe_id: + pipe_id, _ = pipe_id.split(".", 1) + return pipe_id + + def get_function_params(function_module, form_data, user, extra_params=None): + if extra_params is None: + extra_params = {} + + pipe_id = get_pipe_id(form_data) + + # Get the signature of the function + sig = inspect.signature(function_module.pipe) + params = {"body": form_data} | { + k: v for k, v in extra_params.items() if k in sig.parameters + } + + if "__user__" in params and hasattr(function_module, "UserValves"): + user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) + try: + params["__user__"]["valves"] = function_module.UserValves(**user_valves) + except Exception as e: + log.exception(e) + params["__user__"]["valves"] = function_module.UserValves() + + return params + + model_id = form_data.get("model") + model_info = Models.get_model_by_id(model_id) + + metadata = form_data.pop("metadata", {}) + + files = metadata.get("files", []) + tool_ids = metadata.get("tool_ids", []) + # Check if tool_ids is None + if tool_ids is None: + tool_ids = [] + + __event_emitter__ = None + __event_call__ = None + __task__ = None + __task_body__ = None + + if metadata: + if all(k in metadata for k in ("session_id", "chat_id", "message_id")): + __event_emitter__ = get_event_emitter(metadata) + __event_call__ = get_event_call(metadata) + __task__ = metadata.get("task", None) + __task_body__ = metadata.get("task_body", None) + + extra_params = { + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__task__": __task__, + "__task_body__": __task_body__, + "__files__": files, + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + "__metadata__": metadata, + } + extra_params["__tools__"] = get_tools( + request, + tool_ids, + user, + { + **extra_params, + "__model__": models.get(form_data["model"], None), + "__messages__": form_data["messages"], + "__files__": files, + }, + ) + + if model_info: + if model_info.base_model_id: + form_data["model"] = model_info.base_model_id + + params = model_info.params.model_dump() + form_data = apply_model_params_to_body_openai(params, form_data) + form_data = apply_model_system_prompt_to_body(params, form_data, user) + + pipe_id = get_pipe_id(form_data) + function_module = get_function_module_by_id(pipe_id) + + pipe = function_module.pipe + params = get_function_params(function_module, form_data, user, extra_params) + + if form_data.get("stream", False): + + async def stream_content(): + try: + res = await execute_pipe(pipe, params) + + # Directly return if the response is a StreamingResponse + if isinstance(res, StreamingResponse): + async for data in res.body_iterator: + yield data + return + if isinstance(res, dict): + yield f"data: {json.dumps(res)}\n\n" + return + + except Exception as e: + log.error(f"Error: {e}") + yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" + return + + if isinstance(res, str): + message = openai_chat_chunk_message_template(form_data["model"], res) + yield f"data: {json.dumps(message)}\n\n" + + if isinstance(res, Iterator): + for line in res: + yield process_line(form_data, line) + + if isinstance(res, AsyncGenerator): + async for line in res: + yield process_line(form_data, line) + + if isinstance(res, str) or isinstance(res, Generator): + finish_message = openai_chat_chunk_message_template( + form_data["model"], "" + ) + finish_message["choices"][0]["finish_reason"] = "stop" + yield f"data: {json.dumps(finish_message)}\n\n" + yield "data: [DONE]" + + return StreamingResponse(stream_content(), media_type="text/event-stream") + else: + try: + res = await execute_pipe(pipe, params) + + except Exception as e: + log.error(f"Error: {e}") + return {"error": {"detail": str(e)}} + + if isinstance(res, StreamingResponse) or isinstance(res, dict): + return res + if isinstance(res, BaseModel): + return res.model_dump() + + message = await get_message_content(res) + return openai_chat_completion_message_template(form_data["model"], message) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index eba8e51e8..ca0522289 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -8,7 +8,6 @@ import shutil import sys import time import random -from typing import AsyncGenerator, Generator, Iterator from contextlib import asynccontextmanager from urllib.parse import urlencode, parse_qs, urlparse @@ -45,7 +44,6 @@ from open_webui.socket.main import ( get_event_call, get_event_emitter, ) - from open_webui.routers import ( audio, images, @@ -71,14 +69,6 @@ from open_webui.routers import ( utils, ) -from open_webui.routers.openai import ( - generate_chat_completion as generate_openai_chat_completion, -) - -from open_webui.routers.ollama import ( - generate_chat_completion as generate_ollama_chat_completion, -) - from open_webui.routers.retrieval import ( get_embedding_function, get_ef, @@ -86,7 +76,6 @@ from open_webui.routers.retrieval import ( ) from open_webui.routers.pipelines import ( process_pipeline_inlet_filter, - process_pipeline_outlet_filter, ) from open_webui.retrieval.utils import get_sources_from_files @@ -284,6 +273,15 @@ from open_webui.env import ( OFFLINE_MODE, ) + +from open_webui.utils.models import get_all_models, get_all_base_models +from open_webui.utils.chat import ( + generate_chat_completion as chat_completion_handler, + chat_completed as chat_completed_handler, + chat_action as chat_action_handler, +) + + from open_webui.utils.plugin import load_function_module_by_id from open_webui.utils.misc import ( add_or_update_system_message, @@ -292,10 +290,7 @@ from open_webui.utils.misc import ( openai_chat_chunk_message_template, openai_chat_completion_message_template, ) -from open_webui.utils.payload import ( - apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, -) + from open_webui.utils.payload import convert_payload_openai_to_ollama from open_webui.utils.response import ( @@ -772,44 +767,42 @@ async def chat_completion_filter_functions_handler(body, model, extra_params): return body, {} -def get_tools_function_calling_payload(messages, task_model_id, content): - user_message = get_last_user_message(messages) - history = "\n".join( - f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" - for message in messages[::-1][:4] - ) - - prompt = f"History:\n{history}\nQuery: {user_message}" - - return { - "model": task_model_id, - "messages": [ - {"role": "system", "content": content}, - {"role": "user", "content": f"Query: {prompt}"}, - ], - "stream": False, - "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, - } - - -async def get_content_from_response(response) -> Optional[str]: - content = None - if hasattr(response, "body_iterator"): - async for chunk in response.body_iterator: - data = json.loads(chunk.decode("utf-8")) - content = data["choices"][0]["message"]["content"] - - # Cleanup any remaining background tasks if necessary - if response.background is not None: - await response.background() - else: - content = response["choices"][0]["message"]["content"] - return content - - async def chat_completion_tools_handler( - body: dict, user: UserModel, models, extra_params: dict + request: Request, body: dict, user: UserModel, models, extra_params: dict ) -> tuple[dict, dict]: + async def get_content_from_response(response) -> Optional[str]: + content = None + if hasattr(response, "body_iterator"): + async for chunk in response.body_iterator: + data = json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() + else: + content = response["choices"][0]["message"]["content"] + return content + + def get_tools_function_calling_payload(messages, task_model_id, content): + user_message = get_last_user_message(messages) + history = "\n".join( + f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" + for message in messages[::-1][:4] + ) + + prompt = f"History:\n{history}\nQuery: {user_message}" + + return { + "model": task_model_id, + "messages": [ + {"role": "system", "content": content}, + {"role": "user", "content": f"Query: {prompt}"}, + ], + "stream": False, + "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, + } + # If tool_ids field is present, call the functions metadata = body.get("metadata", {}) @@ -823,12 +816,12 @@ async def chat_completion_tools_handler( task_model_id = get_task_model_id( body["model"], - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, models, ) tools = get_tools( - app, + request, tool_ids, user, { @@ -948,7 +941,7 @@ async def chat_completion_tools_handler( async def chat_completion_files_handler( - body: dict, user: UserModel + request: Request, body: dict, user: UserModel ) -> tuple[dict, dict[str, list]]: sources = [] @@ -986,36 +979,17 @@ async def chat_completion_files_handler( sources = get_sources_from_files( files=files, queries=queries, - embedding_function=app.state.EMBEDDING_FUNCTION, - k=app.state.config.TOP_K, - reranking_function=app.state.rf, - r=app.state.config.RELEVANCE_THRESHOLD, - hybrid_search=app.state.config.ENABLE_RAG_HYBRID_SEARCH, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=request.app.state.config.TOP_K, + reranking_function=request.app.state.rf, + r=request.app.state.config.RELEVANCE_THRESHOLD, + hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, ) log.debug(f"rag_contexts:sources: {sources}") return body, {"sources": sources} -async def get_body_and_model_and_user(request, models): - # Read the original request body - body = await request.body() - body_str = body.decode("utf-8") - body = json.loads(body_str) if body_str else {} - - model_id = body["model"] - if model_id not in models: - raise Exception("Model not found") - model = models[model_id] - - user = get_current_user( - request, - get_http_authorization_cred(request.headers.get("Authorization")), - ) - - return body, model, user - - class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): if not ( @@ -1031,6 +1005,24 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): await get_all_models(request) models = app.state.MODELS + async def get_body_and_model_and_user(request, models): + # Read the original request body + body = await request.body() + body_str = body.decode("utf-8") + body = json.loads(body_str) if body_str else {} + + model_id = body["model"] + if model_id not in models: + raise Exception("Model not found") + model = models[model_id] + + user = get_current_user( + request, + get_http_authorization_cred(request.headers.get("Authorization")), + ) + + return body, model, user + try: body, model, user = await get_body_and_model_and_user(request, models) except Exception as e: @@ -1118,14 +1110,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): try: body, flags = await chat_completion_tools_handler( - body, user, models, extra_params + request, body, user, models, extra_params ) sources.extend(flags.get("sources", [])) except Exception as e: log.exception(e) try: - body, flags = await chat_completion_files_handler(body, user) + body, flags = await chat_completion_files_handler(request, body, user) sources.extend(flags.get("sources", [])) except Exception as e: log.exception(e) @@ -1378,15 +1370,12 @@ app.include_router(ollama.router, prefix="/ollama", tags=["ollama"]) app.include_router(openai.router, prefix="/openai", tags=["openai"]) -app.include_router(pipelines.router, prefix="/api/pipelines", tags=["pipelines"]) -app.include_router(tasks.router, prefix="/api/tasks", tags=["tasks"]) - - +app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"]) +app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"]) app.include_router(images.router, prefix="/api/v1/images", tags=["images"]) app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"]) app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"]) - app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"]) app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"]) @@ -1417,483 +1406,9 @@ app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"]) ################################## -def get_function_module(pipe_id: str): - # Check if function is already loaded - if pipe_id not in app.state.FUNCTIONS: - function_module, _, _ = load_function_module_by_id(pipe_id) - app.state.FUNCTIONS[pipe_id] = function_module - else: - function_module = app.state.FUNCTIONS[pipe_id] - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(pipe_id) - function_module.valves = function_module.Valves(**(valves if valves else {})) - return function_module - - -async def get_function_models(): - pipes = Functions.get_functions_by_type("pipe", active_only=True) - pipe_models = [] - - for pipe in pipes: - function_module = get_function_module(pipe.id) - - # Check if function is a manifold - if hasattr(function_module, "pipes"): - sub_pipes = [] - - # Check if pipes is a function or a list - - try: - if callable(function_module.pipes): - sub_pipes = function_module.pipes() - else: - sub_pipes = function_module.pipes - except Exception as e: - log.exception(e) - sub_pipes = [] - - log.debug( - f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}" - ) - - for p in sub_pipes: - sub_pipe_id = f'{pipe.id}.{p["id"]}' - sub_pipe_name = p["name"] - - if hasattr(function_module, "name"): - sub_pipe_name = f"{function_module.name}{sub_pipe_name}" - - pipe_flag = {"type": pipe.type} - - pipe_models.append( - { - "id": sub_pipe_id, - "name": sub_pipe_name, - "object": "model", - "created": pipe.created_at, - "owned_by": "openai", - "pipe": pipe_flag, - } - ) - else: - pipe_flag = {"type": "pipe"} - - log.debug( - f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}" - ) - - pipe_models.append( - { - "id": pipe.id, - "name": pipe.name, - "object": "model", - "created": pipe.created_at, - "owned_by": "openai", - "pipe": pipe_flag, - } - ) - - return pipe_models - - -async def generate_function_chat_completion(form_data, user, models: dict = {}): - async def execute_pipe(pipe, params): - if inspect.iscoroutinefunction(pipe): - return await pipe(**params) - else: - return pipe(**params) - - async def get_message_content(res: str | Generator | AsyncGenerator) -> str: - if isinstance(res, str): - return res - if isinstance(res, Generator): - return "".join(map(str, res)) - if isinstance(res, AsyncGenerator): - return "".join([str(stream) async for stream in res]) - - def process_line(form_data: dict, line): - if isinstance(line, BaseModel): - line = line.model_dump_json() - line = f"data: {line}" - if isinstance(line, dict): - line = f"data: {json.dumps(line)}" - - try: - line = line.decode("utf-8") - except Exception: - pass - - if line.startswith("data:"): - return f"{line}\n\n" - else: - line = openai_chat_chunk_message_template(form_data["model"], line) - return f"data: {json.dumps(line)}\n\n" - - def get_pipe_id(form_data: dict) -> str: - pipe_id = form_data["model"] - if "." in pipe_id: - pipe_id, _ = pipe_id.split(".", 1) - return pipe_id - - def get_function_params(function_module, form_data, user, extra_params=None): - if extra_params is None: - extra_params = {} - - pipe_id = get_pipe_id(form_data) - - # Get the signature of the function - sig = inspect.signature(function_module.pipe) - params = {"body": form_data} | { - k: v for k, v in extra_params.items() if k in sig.parameters - } - - if "__user__" in params and hasattr(function_module, "UserValves"): - user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) - try: - params["__user__"]["valves"] = function_module.UserValves(**user_valves) - except Exception as e: - log.exception(e) - params["__user__"]["valves"] = function_module.UserValves() - - return params - - model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id) - - metadata = form_data.pop("metadata", {}) - - files = metadata.get("files", []) - tool_ids = metadata.get("tool_ids", []) - # Check if tool_ids is None - if tool_ids is None: - tool_ids = [] - - __event_emitter__ = None - __event_call__ = None - __task__ = None - __task_body__ = None - - if metadata: - if all(k in metadata for k in ("session_id", "chat_id", "message_id")): - __event_emitter__ = get_event_emitter(metadata) - __event_call__ = get_event_call(metadata) - __task__ = metadata.get("task", None) - __task_body__ = metadata.get("task_body", None) - - extra_params = { - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__task__": __task__, - "__task_body__": __task_body__, - "__files__": files, - "__user__": { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - "__metadata__": metadata, - } - extra_params["__tools__"] = get_tools( - app, - tool_ids, - user, - { - **extra_params, - "__model__": models.get(form_data["model"], None), - "__messages__": form_data["messages"], - "__files__": files, - }, - ) - - if model_info: - if model_info.base_model_id: - form_data["model"] = model_info.base_model_id - - params = model_info.params.model_dump() - form_data = apply_model_params_to_body_openai(params, form_data) - form_data = apply_model_system_prompt_to_body(params, form_data, user) - - pipe_id = get_pipe_id(form_data) - function_module = get_function_module(pipe_id) - - pipe = function_module.pipe - params = get_function_params(function_module, form_data, user, extra_params) - - if form_data.get("stream", False): - - async def stream_content(): - try: - res = await execute_pipe(pipe, params) - - # Directly return if the response is a StreamingResponse - if isinstance(res, StreamingResponse): - async for data in res.body_iterator: - yield data - return - if isinstance(res, dict): - yield f"data: {json.dumps(res)}\n\n" - return - - except Exception as e: - log.error(f"Error: {e}") - yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" - return - - if isinstance(res, str): - message = openai_chat_chunk_message_template(form_data["model"], res) - yield f"data: {json.dumps(message)}\n\n" - - if isinstance(res, Iterator): - for line in res: - yield process_line(form_data, line) - - if isinstance(res, AsyncGenerator): - async for line in res: - yield process_line(form_data, line) - - if isinstance(res, str) or isinstance(res, Generator): - finish_message = openai_chat_chunk_message_template( - form_data["model"], "" - ) - finish_message["choices"][0]["finish_reason"] = "stop" - yield f"data: {json.dumps(finish_message)}\n\n" - yield "data: [DONE]" - - return StreamingResponse(stream_content(), media_type="text/event-stream") - else: - try: - res = await execute_pipe(pipe, params) - - except Exception as e: - log.error(f"Error: {e}") - return {"error": {"detail": str(e)}} - - if isinstance(res, StreamingResponse) or isinstance(res, dict): - return res - if isinstance(res, BaseModel): - return res.model_dump() - - message = await get_message_content(res) - return openai_chat_completion_message_template(form_data["model"], message) - - -async def get_all_base_models(request): - function_models = [] - openai_models = [] - ollama_models = [] - - if app.state.config.ENABLE_OPENAI_API: - openai_models = await openai.get_all_models(request) - openai_models = openai_models["data"] - - if app.state.config.ENABLE_OLLAMA_API: - ollama_models = await ollama.get_all_models(request) - ollama_models = [ - { - "id": model["model"], - "name": model["name"], - "object": "model", - "created": int(time.time()), - "owned_by": "ollama", - "ollama": model, - } - for model in ollama_models["models"] - ] - - function_models = await get_function_models() - models = function_models + openai_models + ollama_models - - # Add arena models - if app.state.config.ENABLE_EVALUATION_ARENA_MODELS: - arena_models = [] - if len(app.state.config.EVALUATION_ARENA_MODELS) > 0: - arena_models = [ - { - "id": model["id"], - "name": model["name"], - "info": { - "meta": model["meta"], - }, - "object": "model", - "created": int(time.time()), - "owned_by": "arena", - "arena": True, - } - for model in app.state.config.EVALUATION_ARENA_MODELS - ] - else: - # Add default arena model - arena_models = [ - { - "id": DEFAULT_ARENA_MODEL["id"], - "name": DEFAULT_ARENA_MODEL["name"], - "info": { - "meta": DEFAULT_ARENA_MODEL["meta"], - }, - "object": "model", - "created": int(time.time()), - "owned_by": "arena", - "arena": True, - } - ] - models = models + arena_models - - return models - - -@cached(ttl=3) -async def get_all_models(request): - models = await get_all_base_models(request) - - # If there are no models, return an empty list - if len([model for model in models if not model.get("arena", False)]) == 0: - return [] - - global_action_ids = [ - function.id for function in Functions.get_global_action_functions() - ] - enabled_action_ids = [ - function.id - for function in Functions.get_functions_by_type("action", active_only=True) - ] - - custom_models = Models.get_all_models() - for custom_model in custom_models: - if custom_model.base_model_id is None: - for model in models: - if ( - custom_model.id == model["id"] - or custom_model.id == model["id"].split(":")[0] - ): - if custom_model.is_active: - model["name"] = custom_model.name - model["info"] = custom_model.model_dump() - - action_ids = [] - if "info" in model and "meta" in model["info"]: - action_ids.extend( - model["info"]["meta"].get("actionIds", []) - ) - - model["action_ids"] = action_ids - else: - models.remove(model) - - elif custom_model.is_active and ( - custom_model.id not in [model["id"] for model in models] - ): - owned_by = "openai" - pipe = None - action_ids = [] - - for model in models: - if ( - custom_model.base_model_id == model["id"] - or custom_model.base_model_id == model["id"].split(":")[0] - ): - owned_by = model["owned_by"] - if "pipe" in model: - pipe = model["pipe"] - break - - if custom_model.meta: - meta = custom_model.meta.model_dump() - if "actionIds" in meta: - action_ids.extend(meta["actionIds"]) - - models.append( - { - "id": f"{custom_model.id}", - "name": custom_model.name, - "object": "model", - "created": custom_model.created_at, - "owned_by": owned_by, - "info": custom_model.model_dump(), - "preset": True, - **({"pipe": pipe} if pipe is not None else {}), - "action_ids": action_ids, - } - ) - - # Process action_ids to get the actions - def get_action_items_from_module(function, module): - actions = [] - if hasattr(module, "actions"): - actions = module.actions - return [ - { - "id": f"{function.id}.{action['id']}", - "name": action.get("name", f"{function.name} ({action['id']})"), - "description": function.meta.description, - "icon_url": action.get( - "icon_url", function.meta.manifest.get("icon_url", None) - ), - } - for action in actions - ] - else: - return [ - { - "id": function.id, - "name": function.name, - "description": function.meta.description, - "icon_url": function.meta.manifest.get("icon_url", None), - } - ] - - def get_function_module_by_id(function_id): - if function_id in app.state.FUNCTIONS: - function_module = app.state.FUNCTIONS[function_id] - else: - function_module, _, _ = load_function_module_by_id(function_id) - app.state.FUNCTIONS[function_id] = function_module - - for model in models: - action_ids = [ - action_id - for action_id in list(set(model.pop("action_ids", []) + global_action_ids)) - if action_id in enabled_action_ids - ] - - model["actions"] = [] - for action_id in action_ids: - action_function = Functions.get_function_by_id(action_id) - if action_function is None: - raise Exception(f"Action not found: {action_id}") - - function_module = get_function_module_by_id(action_id) - model["actions"].extend( - get_action_items_from_module(action_function, function_module) - ) - log.debug(f"get_all_models() returned {len(models)} models") - - app.state.MODELS = {model["id"]: model for model in models} - return models - - @app.get("/api/models") async def get_models(request: Request, user=Depends(get_verified_user)): - models = await get_all_models(request) - - # Filter out filter pipelines - models = [ - model - for model in models - if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" - ] - - model_order_list = app.state.config.MODEL_ORDER_LIST - if model_order_list: - model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)} - # Sort models by order list priority, with fallback for those not in the list - models.sort( - key=lambda x: (model_order_dict.get(x["id"], float("inf")), x["name"]) - ) - - # Filter out models that the user does not have access to - if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: + def get_filtered_models(models, user): filtered_models = [] for model in models: if model.get("arena"): @@ -1913,393 +1428,87 @@ async def get_models(request: Request, user=Depends(get_verified_user)): user.id, type="read", access_control=model_info.access_control ): filtered_models.append(model) - models = filtered_models + + return filtered_models + + models = await get_all_models(request) + + # Filter out filter pipelines + models = [ + model + for model in models + if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" + ] + + model_order_list = request.app.state.config.MODEL_ORDER_LIST + if model_order_list: + model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)} + # Sort models by order list priority, with fallback for those not in the list + models.sort( + key=lambda x: (model_order_dict.get(x["id"], float("inf")), x["name"]) + ) + + # Filter out models that the user does not have access to + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: + models = get_filtered_models(models, user) log.debug( f"/api/models returned filtered models accessible to the user: {json.dumps([model['id'] for model in models])}" ) - return {"data": models} @app.get("/api/models/base") async def get_base_models(request: Request, user=Depends(get_admin_user)): models = await get_all_base_models(request) - - # Filter out arena models - models = [model for model in models if not model.get("arena", False)] return {"data": models} @app.post("/api/chat/completions") -async def generate_chat_completions( +async def chat_completion( request: Request, form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False, ): - if BYPASS_MODEL_ACCESS_CONTROL: - bypass_filter = True - - models = app.state.MODELS - - model_id = form_data["model"] - if model_id not in models: + try: + return await chat_completion_handler(request, form_data, user, bypass_filter) + except Exception as e: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), ) - model = models[model_id] - # Check if user has access to the model - if not bypass_filter and user.role == "user": - if model.get("arena"): - if not has_access( - user.id, - type="read", - access_control=model.get("info", {}) - .get("meta", {}) - .get("access_control", {}), - ): - raise HTTPException( - status_code=403, - detail="Model not found", - ) - else: - model_info = Models.get_model_by_id(model_id) - if not model_info: - raise HTTPException( - status_code=404, - detail="Model not found", - ) - elif not ( - user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control - ) - ): - raise HTTPException( - status_code=403, - detail="Model not found", - ) - - if model["owned_by"] == "arena": - model_ids = model.get("info", {}).get("meta", {}).get("model_ids") - filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") - if model_ids and filter_mode == "exclude": - model_ids = [ - model["id"] - for model in await get_all_models(request) - if model.get("owned_by") != "arena" and model["id"] not in model_ids - ] - - selected_model_id = None - if isinstance(model_ids, list) and model_ids: - selected_model_id = random.choice(model_ids) - else: - model_ids = [ - model["id"] - for model in await get_all_models(request) - if model.get("owned_by") != "arena" - ] - selected_model_id = random.choice(model_ids) - - form_data["model"] = selected_model_id - - if form_data.get("stream") == True: - - async def stream_wrapper(stream): - yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" - async for chunk in stream: - yield chunk - - response = await generate_chat_completions( - form_data, user, bypass_filter=True - ) - return StreamingResponse( - stream_wrapper(response.body_iterator), media_type="text/event-stream" - ) - else: - return { - **( - await generate_chat_completions(form_data, user, bypass_filter=True) - ), - "selected_model_id": selected_model_id, - } - - if model.get("pipe"): - # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter - return await generate_function_chat_completion( - form_data, user=user, models=models - ) - if model["owned_by"] == "ollama": - # Using /ollama/api/chat endpoint - form_data = convert_payload_openai_to_ollama(form_data) - response = await generate_ollama_chat_completion( - request=request, form_data=form_data, user=user, bypass_filter=bypass_filter - ) - if form_data.stream: - response.headers["content-type"] = "text/event-stream" - return StreamingResponse( - convert_streaming_response_ollama_to_openai(response), - headers=dict(response.headers), - ) - else: - return convert_response_ollama_to_openai(response) - else: - return await generate_openai_chat_completion( - request=request, form_data=form_data, user=user, bypass_filter=bypass_filter - ) +generate_chat_completions = chat_completion +generate_chat_completion = chat_completion @app.post("/api/chat/completed") async def chat_completed( request: Request, form_data: dict, user=Depends(get_verified_user) ): - await get_all_models(request) - models = app.state.MODELS - - data = form_data - model_id = data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - model = models[model_id] - try: - data = process_pipeline_outlet_filter(request, data, user, models) + return await chat_completed_handler(request, form_data, user) except Exception as e: - return HTTPException( + raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=str(e), ) - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel to include vavles - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - # Sort filter_ids by priority, using the get_priority function - filter_ids.sort(key=get_priority) - - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in app.state.FUNCTIONS: - function_module = app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - app.state.FUNCTIONS[filter_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if not hasattr(function_module, "outlet"): - continue - try: - outlet = function_module.outlet - - # Get the signature of the function - sig = inspect.signature(outlet) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(outlet): - data = await outlet(**params) - else: - data = outlet(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data - @app.post("/api/chat/actions/{action_id}") async def chat_action( request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user) ): - if "." in action_id: - action_id, sub_action_id = action_id.split(".") - else: - sub_action_id = None - - action = Functions.get_function_by_id(action_id) - if not action: + try: + return await chat_action_handler(request, action_id, form_data, user) + except Exception as e: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Action not found", + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), ) - await get_all_models(request) - models = app.state.MODELS - - data = form_data - model_id = data["model"] - - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - model = models[model_id] - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - if action_id in app.state.FUNCTIONS: - function_module = app.state.FUNCTIONS[action_id] - else: - function_module, _, _ = load_function_module_by_id(action_id) - app.state.FUNCTIONS[action_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(action_id) - function_module.valves = function_module.Valves(**(valves if valves else {})) - - if hasattr(function_module, "action"): - try: - action = function_module.action - - # Get the signature of the function - sig = inspect.signature(action) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": sub_action_id if sub_action_id is not None else action_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - action_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(action): - data = await action(**params) - else: - data = action(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data - ################################## # diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 0deded03e..2beec59f7 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -41,7 +41,7 @@ router = APIRouter() @router.get("/config") async def get_config(request: Request, user=Depends(get_admin_user)): return { - "enabled": request.app.state.config.ENABLED, + "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, "engine": request.app.state.config.ENGINE, "openai": { "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL, @@ -94,7 +94,7 @@ async def update_config( request: Request, form_data: ConfigForm, user=Depends(get_admin_user) ): request.app.state.config.ENGINE = form_data.engine - request.app.state.config.ENABLED = form_data.enabled + request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY @@ -131,7 +131,7 @@ async def update_config( ) return { - "enabled": request.app.state.config.ENABLED, + "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, "engine": request.app.state.config.ENGINE, "openai": { "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL, @@ -175,7 +175,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)): r.raise_for_status() return True except Exception: - request.app.state.config.ENABLED = False + request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) elif request.app.state.config.ENGINE == "comfyui": try: @@ -185,7 +185,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)): r.raise_for_status() return True except Exception: - request.app.state.config.ENABLED = False + request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) else: return True @@ -232,7 +232,7 @@ def get_image_model(): options = r.json() return options["sd_model_checkpoint"] except Exception as e: - request.app.state.config.ENABLED = False + request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -351,7 +351,7 @@ def get_models(request: Request, user=Depends(get_verified_user)): ) ) except Exception as e: - request.app.state.config.ENABLED = False + request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index c40208ac1..e577f70f1 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -195,12 +195,12 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)): "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { - "url": request.app.state.config.OPENAI_API_BASE_URL, - "key": request.app.state.config.OPENAI_API_KEY, + "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, + "key": request.app.state.config.RAG_OPENAI_API_KEY, }, "ollama_config": { - "url": request.app.state.config.OLLAMA_BASE_URL, - "key": request.app.state.config.OLLAMA_API_KEY, + "url": request.app.state.config.RAG_OLLAMA_BASE_URL, + "key": request.app.state.config.RAG_OLLAMA_API_KEY, }, } @@ -244,14 +244,20 @@ async def update_embedding_config( if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: if form_data.openai_config is not None: - request.app.state.config.OPENAI_API_BASE_URL = ( + request.app.state.config.RAG_OPENAI_API_BASE_URL = ( form_data.openai_config.url ) - request.app.state.config.OPENAI_API_KEY = form_data.openai_config.key + request.app.state.config.RAG_OPENAI_API_KEY = ( + form_data.openai_config.key + ) if form_data.ollama_config is not None: - request.app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url - request.app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key + request.app.state.config.RAG_OLLAMA_BASE_URL = ( + form_data.ollama_config.url + ) + request.app.state.config.RAG_OLLAMA_API_KEY = ( + form_data.ollama_config.key + ) request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( form_data.embedding_batch_size @@ -267,14 +273,14 @@ async def update_embedding_config( request.app.state.config.RAG_EMBEDDING_MODEL, request.app.state.ef, ( - request.app.state.config.OPENAI_API_BASE_URL + request.app.state.config.RAG_OPENAI_API_BASE_URL if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.OLLAMA_BASE_URL + else request.app.state.config.RAG_OLLAMA_BASE_URL ), ( - request.app.state.config.OPENAI_API_KEY + request.app.state.config.RAG_OPENAI_API_KEY if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.OLLAMA_API_KEY + else request.app.state.config.RAG_OLLAMA_API_KEY ), request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) @@ -285,12 +291,12 @@ async def update_embedding_config( "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { - "url": request.app.state.config.OPENAI_API_BASE_URL, - "key": request.app.state.config.OPENAI_API_KEY, + "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, + "key": request.app.state.config.RAG_OPENAI_API_KEY, }, "ollama_config": { - "url": request.app.state.config.OLLAMA_BASE_URL, - "key": request.app.state.config.OLLAMA_API_KEY, + "url": request.app.state.config.RAG_OLLAMA_BASE_URL, + "key": request.app.state.config.RAG_OLLAMA_API_KEY, }, } except Exception as e: @@ -747,14 +753,14 @@ def save_docs_to_vector_db( request.app.state.config.RAG_EMBEDDING_MODEL, request.app.state.ef, ( - request.app.state.config.OPENAI_API_BASE_URL + request.app.state.config.RAG_OPENAI_API_BASE_URL if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.OLLAMA_BASE_URL + else request.app.state.config.RAG_OLLAMA_BASE_URL ), ( - request.app.state.config.OPENAI_API_KEY + request.app.state.config.RAG_OPENAI_API_KEY if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.OLLAMA_API_KEY + else request.app.state.config.RAG_OLLAMA_API_KEY ), request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index 13e5a95a3..425a3ef02 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from typing import Optional import logging - +from open_webui.utils.chat import generate_chat_completion from open_webui.utils.task import ( title_generation_template, query_generation_template, @@ -193,7 +193,7 @@ Artificial Intelligence in Healthcare # Handle pipeline filters try: - payload = process_pipeline_inlet_filter(payload, user, models) + payload = process_pipeline_inlet_filter(request, payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -208,7 +208,7 @@ Artificial Intelligence in Healthcare if "chat_id" in payload: del payload["chat_id"] - return await generate_chat_completions(form_data=payload, user=user) + return await generate_chat_completion(request, form_data=payload, user=user) @router.post("/tags/completions") @@ -282,7 +282,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } # Handle pipeline filters try: - payload = process_pipeline_inlet_filter(payload, user, models) + payload = process_pipeline_inlet_filter(request, payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -297,7 +297,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } if "chat_id" in payload: del payload["chat_id"] - return await generate_chat_completions(form_data=payload, user=user) + return await generate_chat_completion(request, form_data=payload, user=user) @router.post("/queries/completions") @@ -363,7 +363,7 @@ async def generate_queries( # Handle pipeline filters try: - payload = process_pipeline_inlet_filter(payload, user, models) + payload = process_pipeline_inlet_filter(request, payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -378,7 +378,7 @@ async def generate_queries( if "chat_id" in payload: del payload["chat_id"] - return await generate_chat_completions(form_data=payload, user=user) + return await generate_chat_completion(request, form_data=payload, user=user) @router.post("/auto/completions") @@ -449,7 +449,7 @@ async def generate_autocompletion( # Handle pipeline filters try: - payload = process_pipeline_inlet_filter(payload, user, models) + payload = process_pipeline_inlet_filter(request, payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -464,7 +464,7 @@ async def generate_autocompletion( if "chat_id" in payload: del payload["chat_id"] - return await generate_chat_completions(form_data=payload, user=user) + return await generate_chat_completion(request, form_data=payload, user=user) @router.post("/emoji/completions") @@ -523,7 +523,7 @@ Message: """{{prompt}}""" # Handle pipeline filters try: - payload = process_pipeline_inlet_filter(payload, user, models) + payload = process_pipeline_inlet_filter(request, payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -538,7 +538,7 @@ Message: """{{prompt}}""" if "chat_id" in payload: del payload["chat_id"] - return await generate_chat_completions(form_data=payload, user=user) + return await generate_chat_completion(request, form_data=payload, user=user) @router.post("/moa/completions") @@ -590,7 +590,7 @@ Responses from models: {{responses}}""" } try: - payload = process_pipeline_inlet_filter(payload, user, models) + payload = process_pipeline_inlet_filter(request, payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -605,4 +605,4 @@ Responses from models: {{responses}}""" if "chat_id" in payload: del payload["chat_id"] - return await generate_chat_completions(form_data=payload, user=user) + return await generate_chat_completion(request, form_data=payload, user=user) diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py new file mode 100644 index 000000000..f127e6eb4 --- /dev/null +++ b/backend/open_webui/utils/chat.py @@ -0,0 +1,380 @@ +import time +import logging +import sys + +from aiocache import cached +from typing import Any +import random +import json +import inspect + +from fastapi import Request +from starlette.responses import Response, StreamingResponse + +from open_webui.socket.main import ( + get_event_call, + get_event_emitter, +) + +from open_webui.functions import generate_function_chat_completion + +from open_webui.routers.openai import ( + generate_chat_completion as generate_openai_chat_completion, +) + +from open_webui.routers.ollama import ( + generate_chat_completion as generate_ollama_chat_completion, +) + +from open_webui.routers.pipelines import ( + process_pipeline_outlet_filter, +) + + +from open_webui.models.functions import Functions +from open_webui.models.models import Models + + +from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.access_control import has_access +from open_webui.utils.models import get_all_models +from open_webui.utils.payload import convert_payload_openai_to_ollama +from open_webui.utils.response import ( + convert_response_ollama_to_openai, + convert_streaming_response_ollama_to_openai, +) +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL + + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +async def generate_chat_completion( + request: Request, + form_data: dict, + user: Any, + bypass_filter: bool = False, +): + if BYPASS_MODEL_ACCESS_CONTROL: + bypass_filter = True + + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise Exception("Model not found") + + model = models[model_id] + + # Check if user has access to the model + if not bypass_filter and user.role == "user": + if model.get("arena"): + if not has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + raise Exception("Model not found") + else: + model_info = Models.get_model_by_id(model_id) + if not model_info: + raise Exception("Model not found") + elif not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise Exception("Model not found") + + if model["owned_by"] == "arena": + model_ids = model.get("info", {}).get("meta", {}).get("model_ids") + filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") + if model_ids and filter_mode == "exclude": + model_ids = [ + model["id"] + for model in await get_all_models(request) + if model.get("owned_by") != "arena" and model["id"] not in model_ids + ] + + selected_model_id = None + if isinstance(model_ids, list) and model_ids: + selected_model_id = random.choice(model_ids) + else: + model_ids = [ + model["id"] + for model in await get_all_models(request) + if model.get("owned_by") != "arena" + ] + selected_model_id = random.choice(model_ids) + + form_data["model"] = selected_model_id + + if form_data.get("stream") == True: + + async def stream_wrapper(stream): + yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" + async for chunk in stream: + yield chunk + + response = await generate_chat_completion( + form_data, user, bypass_filter=True + ) + return StreamingResponse( + stream_wrapper(response.body_iterator), media_type="text/event-stream" + ) + else: + return { + **(await generate_chat_completion(form_data, user, bypass_filter=True)), + "selected_model_id": selected_model_id, + } + + if model.get("pipe"): + # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter + return await generate_function_chat_completion( + form_data, user=user, models=models + ) + if model["owned_by"] == "ollama": + # Using /ollama/api/chat endpoint + form_data = convert_payload_openai_to_ollama(form_data) + response = await generate_ollama_chat_completion( + request=request, form_data=form_data, user=user, bypass_filter=bypass_filter + ) + if form_data.stream: + response.headers["content-type"] = "text/event-stream" + return StreamingResponse( + convert_streaming_response_ollama_to_openai(response), + headers=dict(response.headers), + ) + else: + return convert_response_ollama_to_openai(response) + else: + return await generate_openai_chat_completion( + request=request, form_data=form_data, user=user, bypass_filter=bypass_filter + ) + + +async def chat_completed(request: Request, form_data: dict, user: Any): + await get_all_models(request) + models = request.app.state.MODELS + + data = form_data + model_id = data["model"] + if model_id not in models: + raise Exception("Model not found") + + model = models[model_id] + + try: + data = process_pipeline_outlet_filter(request, data, user, models) + except Exception as e: + return Exception(f"Error: {e}") + + __event_emitter__ = get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + __event_call__ = get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel to include vavles + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [function.id for function in Functions.get_global_filter_functions()] + if "info" in model and "meta" in model["info"]: + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) + + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + filter_ids = [ + filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids + ] + + # Sort filter_ids by priority, using the get_priority function + filter_ids.sort(key=get_priority) + + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if not filter: + continue + + if filter_id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + request.app.state.FUNCTIONS[filter_id] = function_module + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) + + if not hasattr(function_module, "outlet"): + continue + try: + outlet = function_module.outlet + + # Get the signature of the function + sig = inspect.signature(outlet) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": filter_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id + ) + ) + except Exception as e: + print(e) + + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(outlet): + data = await outlet(**params) + else: + data = outlet(**params) + + except Exception as e: + return Exception(f"Error: {e}") + + return data + + +async def chat_action(request: Request, action_id: str, form_data: dict, user: Any): + if "." in action_id: + action_id, sub_action_id = action_id.split(".") + else: + sub_action_id = None + + action = Functions.get_function_by_id(action_id) + if not action: + raise Exception(f"Action not found: {action_id}") + + await get_all_models(request) + models = request.app.state.MODELS + + data = form_data + model_id = data["model"] + + if model_id not in models: + raise Exception("Model not found") + model = models[model_id] + + __event_emitter__ = get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + __event_call__ = get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + if action_id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[action_id] + else: + function_module, _, _ = load_function_module_by_id(action_id) + request.app.state.FUNCTIONS[action_id] = function_module + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(action_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + + if hasattr(function_module, "action"): + try: + action = function_module.action + + # Get the signature of the function + sig = inspect.signature(action) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": sub_action_id if sub_action_id is not None else action_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + action_id, user.id + ) + ) + except Exception as e: + print(e) + + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(action): + data = await action(**params) + else: + data = action(**params) + + except Exception as e: + return Exception(f"Error: {e}") + + return data diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py new file mode 100644 index 000000000..7fdd8b605 --- /dev/null +++ b/backend/open_webui/utils/models.py @@ -0,0 +1,222 @@ +import time +import logging +import sys + +from aiocache import cached +from fastapi import Request + +from open_webui.routers import openai, ollama +from open_webui.functions import get_function_models + + +from open_webui.models.functions import Functions +from open_webui.models.models import Models + + +from open_webui.utils.plugin import load_function_module_by_id + + +from open_webui.config import ( + DEFAULT_ARENA_MODEL, +) + +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL + + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +async def get_all_base_models(request: Request): + function_models = [] + openai_models = [] + ollama_models = [] + + if request.app.state.config.ENABLE_OPENAI_API: + openai_models = await openai.get_all_models(request) + openai_models = openai_models["data"] + + if request.app.state.config.ENABLE_OLLAMA_API: + ollama_models = await ollama.get_all_models(request) + ollama_models = [ + { + "id": model["model"], + "name": model["name"], + "object": "model", + "created": int(time.time()), + "owned_by": "ollama", + "ollama": model, + } + for model in ollama_models["models"] + ] + + function_models = await get_function_models() + models = function_models + openai_models + ollama_models + + return models + + +@cached(ttl=3) +async def get_all_models(request): + models = await get_all_base_models(request) + + # If there are no models, return an empty list + if len(models) == 0: + return [] + + # Add arena models + if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS: + arena_models = [] + if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0: + arena_models = [ + { + "id": model["id"], + "name": model["name"], + "info": { + "meta": model["meta"], + }, + "object": "model", + "created": int(time.time()), + "owned_by": "arena", + "arena": True, + } + for model in request.app.state.config.EVALUATION_ARENA_MODELS + ] + else: + # Add default arena model + arena_models = [ + { + "id": DEFAULT_ARENA_MODEL["id"], + "name": DEFAULT_ARENA_MODEL["name"], + "info": { + "meta": DEFAULT_ARENA_MODEL["meta"], + }, + "object": "model", + "created": int(time.time()), + "owned_by": "arena", + "arena": True, + } + ] + models = models + arena_models + + global_action_ids = [ + function.id for function in Functions.get_global_action_functions() + ] + enabled_action_ids = [ + function.id + for function in Functions.get_functions_by_type("action", active_only=True) + ] + + custom_models = Models.get_all_models() + for custom_model in custom_models: + if custom_model.base_model_id is None: + for model in models: + if ( + custom_model.id == model["id"] + or custom_model.id == model["id"].split(":")[0] + ): + if custom_model.is_active: + model["name"] = custom_model.name + model["info"] = custom_model.model_dump() + + action_ids = [] + if "info" in model and "meta" in model["info"]: + action_ids.extend( + model["info"]["meta"].get("actionIds", []) + ) + + model["action_ids"] = action_ids + else: + models.remove(model) + + elif custom_model.is_active and ( + custom_model.id not in [model["id"] for model in models] + ): + owned_by = "openai" + pipe = None + action_ids = [] + + for model in models: + if ( + custom_model.base_model_id == model["id"] + or custom_model.base_model_id == model["id"].split(":")[0] + ): + owned_by = model["owned_by"] + if "pipe" in model: + pipe = model["pipe"] + break + + if custom_model.meta: + meta = custom_model.meta.model_dump() + if "actionIds" in meta: + action_ids.extend(meta["actionIds"]) + + models.append( + { + "id": f"{custom_model.id}", + "name": custom_model.name, + "object": "model", + "created": custom_model.created_at, + "owned_by": owned_by, + "info": custom_model.model_dump(), + "preset": True, + **({"pipe": pipe} if pipe is not None else {}), + "action_ids": action_ids, + } + ) + + # Process action_ids to get the actions + def get_action_items_from_module(function, module): + actions = [] + if hasattr(module, "actions"): + actions = module.actions + return [ + { + "id": f"{function.id}.{action['id']}", + "name": action.get("name", f"{function.name} ({action['id']})"), + "description": function.meta.description, + "icon_url": action.get( + "icon_url", function.meta.manifest.get("icon_url", None) + ), + } + for action in actions + ] + else: + return [ + { + "id": function.id, + "name": function.name, + "description": function.meta.description, + "icon_url": function.meta.manifest.get("icon_url", None), + } + ] + + def get_function_module_by_id(function_id): + if function_id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[function_id] + else: + function_module, _, _ = load_function_module_by_id(function_id) + request.app.state.FUNCTIONS[function_id] = function_module + + for model in models: + action_ids = [ + action_id + for action_id in list(set(model.pop("action_ids", []) + global_action_ids)) + if action_id in enabled_action_ids + ] + + model["actions"] = [] + for action_id in action_ids: + action_function = Functions.get_function_by_id(action_id) + if action_function is None: + raise Exception(f"Action not found: {action_id}") + + function_module = get_function_module_by_id(action_id) + model["actions"].extend( + get_action_items_from_module(action_function, function_module) + ) + log.debug(f"get_all_models() returned {len(models)} models") + + request.app.state.MODELS = {model["id"]: model for model in models} + return models diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 0b9161f86..b6e13011d 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -4,11 +4,15 @@ import re from typing import Any, Awaitable, Callable, get_type_hints from functools import update_wrapper, partial + +from fastapi import Request +from pydantic import BaseModel, Field, create_model from langchain_core.utils.function_calling import convert_to_openai_function + + from open_webui.models.tools import Tools from open_webui.models.users import UserModel from open_webui.utils.plugin import load_tools_module_by_id -from pydantic import BaseModel, Field, create_model log = logging.getLogger(__name__) @@ -32,7 +36,7 @@ def apply_extra_params_to_tool_function( # Mutation on extra_params def get_tools( - webui_app, tool_ids: list[str], user: UserModel, extra_params: dict + request: Request, tool_ids: list[str], user: UserModel, extra_params: dict ) -> dict[str, dict]: tools_dict = {} @@ -41,10 +45,10 @@ def get_tools( if tools is None: continue - module = webui_app.state.TOOLS.get(tool_id, None) + module = request.app.state.TOOLS.get(tool_id, None) if module is None: module, _ = load_tools_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = module + request.app.state.TOOLS[tool_id] = module extra_params["__id__"] = tool_id if hasattr(module, "valves") and hasattr(module, "Valves"): diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index e76aa3c99..d06fbf3d7 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -110,7 +110,7 @@ export const chatAction = async (token: string, action_id: string, body: ChatAct export const getTaskConfig = async (token: string = '') => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/config`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/config`, { method: 'GET', headers: { Accept: 'application/json', @@ -138,7 +138,7 @@ export const getTaskConfig = async (token: string = '') => { export const updateTaskConfig = async (token: string, config: object) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/config/update`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/config/update`, { method: 'POST', headers: { Accept: 'application/json', @@ -176,7 +176,7 @@ export const generateTitle = async ( ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/title/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/title/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -216,7 +216,7 @@ export const generateTags = async ( ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/tags/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/tags/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -288,7 +288,7 @@ export const generateEmoji = async ( ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/emoji/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/emoji/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -337,7 +337,7 @@ export const generateQueries = async ( ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/queries/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/queries/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -407,7 +407,7 @@ export const generateAutoCompletion = async ( const controller = new AbortController(); let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/auto/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/auto/completions`, { signal: controller.signal, method: 'POST', headers: { @@ -477,7 +477,7 @@ export const generateMoACompletion = async ( const controller = new AbortController(); let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/moa/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/moa/completions`, { signal: controller.signal, method: 'POST', headers: { @@ -507,7 +507,7 @@ export const generateMoACompletion = async ( export const getPipelinesList = async (token: string = '') => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/list`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/list`, { method: 'GET', headers: { Accept: 'application/json', @@ -541,7 +541,7 @@ export const uploadPipeline = async (token: string, file: File, urlIdx: string) formData.append('file', file); formData.append('urlIdx', urlIdx); - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/upload`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/upload`, { method: 'POST', headers: { ...(token && { authorization: `Bearer ${token}` }) @@ -573,7 +573,7 @@ export const uploadPipeline = async (token: string, file: File, urlIdx: string) export const downloadPipeline = async (token: string, url: string, urlIdx: string) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/add`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/add`, { method: 'POST', headers: { Accept: 'application/json', @@ -609,7 +609,7 @@ export const downloadPipeline = async (token: string, url: string, urlIdx: strin export const deletePipeline = async (token: string, id: string, urlIdx: string) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/delete`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/delete`, { method: 'DELETE', headers: { Accept: 'application/json', @@ -650,7 +650,7 @@ export const getPipelines = async (token: string, urlIdx?: string) => { searchParams.append('urlIdx', urlIdx); } - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines?${searchParams.toString()}`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines?${searchParams.toString()}`, { method: 'GET', headers: { Accept: 'application/json', @@ -685,7 +685,7 @@ export const getPipelineValves = async (token: string, pipeline_id: string, urlI } const res = await fetch( - `${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves?${searchParams.toString()}`, + `${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves?${searchParams.toString()}`, { method: 'GET', headers: { @@ -721,7 +721,7 @@ export const getPipelineValvesSpec = async (token: string, pipeline_id: string, } const res = await fetch( - `${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves/spec?${searchParams.toString()}`, + `${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves/spec?${searchParams.toString()}`, { method: 'GET', headers: { @@ -762,7 +762,7 @@ export const updatePipelineValves = async ( } const res = await fetch( - `${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves/update?${searchParams.toString()}`, + `${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves/update?${searchParams.toString()}`, { method: 'POST', headers: { From d8a01cb9116c39154f7cf05ee0091b73ff28b3da Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 12 Dec 2024 20:24:36 -0800 Subject: [PATCH 21/26] wip --- backend/open_webui/routers/images.py | 54 ++++++++++++++-------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 2beec59f7..03e115cc0 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -42,10 +42,10 @@ router = APIRouter() async def get_config(request: Request, user=Depends(get_admin_user)): return { "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, - "engine": request.app.state.config.ENGINE, + "engine": request.app.state.config.IMAGE_GENERATION_ENGINE, "openai": { - "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, }, "automatic1111": { "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, @@ -93,11 +93,13 @@ class ConfigForm(BaseModel): async def update_config( request: Request, form_data: ConfigForm, user=Depends(get_admin_user) ): - request.app.state.config.ENGINE = form_data.engine + request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled - request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL - request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY + request.app.state.config.IMAGES_OPENAI_API_BASE_URL = ( + form_data.openai.OPENAI_API_BASE_URL + ) + request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY request.app.state.config.AUTOMATIC1111_BASE_URL = ( form_data.automatic1111.AUTOMATIC1111_BASE_URL @@ -132,10 +134,10 @@ async def update_config( return { "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, - "engine": request.app.state.config.ENGINE, + "engine": request.app.state.config.IMAGE_GENERATION_ENGINE, "openai": { - "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, }, "automatic1111": { "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, @@ -166,7 +168,7 @@ def get_automatic1111_api_auth(request: Request): @router.get("/config/url/verify") async def verify_url(request: Request, user=Depends(get_admin_user)): - if request.app.state.config.ENGINE == "automatic1111": + if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111": try: r = requests.get( url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", @@ -177,7 +179,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)): except Exception: request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) - elif request.app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": try: r = requests.get( url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" @@ -194,7 +196,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)): def set_image_model(request: Request, model: str): log.info(f"Setting image model to {model}") request.app.state.config.MODEL = model - if request.app.state.config.ENGINE in ["", "automatic1111"]: + if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]: api_auth = get_automatic1111_api_auth() r = requests.get( url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", @@ -212,17 +214,17 @@ def set_image_model(request: Request, model: str): def get_image_model(): - if request.app.state.config.ENGINE == "openai": + if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": return ( request.app.state.config.MODEL if request.app.state.config.MODEL else "dall-e-2" ) - elif request.app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": return request.app.state.config.MODEL if request.app.state.config.MODEL else "" elif ( - request.app.state.config.ENGINE == "automatic1111" - or request.app.state.config.ENGINE == "" + request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" + or request.app.state.config.IMAGE_GENERATION_ENGINE == "" ): try: r = requests.get( @@ -285,12 +287,12 @@ async def update_image_config( @router.get("/models") def get_models(request: Request, user=Depends(get_verified_user)): try: - if request.app.state.config.ENGINE == "openai": + if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": return [ {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] - elif request.app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": # TODO - get models from comfyui r = requests.get( url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" @@ -336,8 +338,8 @@ def get_models(request: Request, user=Depends(get_verified_user)): ) ) elif ( - request.app.state.config.ENGINE == "automatic1111" - or request.app.state.config.ENGINE == "" + request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" + or request.app.state.config.IMAGE_GENERATION_ENGINE == "" ): r = requests.get( url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", @@ -433,10 +435,10 @@ async def image_generations( r = None try: - if request.app.state.config.ENGINE == "openai": + if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": headers = {} headers["Authorization"] = ( - f"Bearer {request.app.state.config.OPENAI_API_KEY}" + f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}" ) headers["Content-Type"] = "application/json" @@ -465,7 +467,7 @@ async def image_generations( # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{request.app.state.config.OPENAI_API_BASE_URL}/images/generations", + url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations", json=data, headers=headers, ) @@ -485,7 +487,7 @@ async def image_generations( return images - elif request.app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": data = { "prompt": form_data.prompt, "width": width, @@ -531,8 +533,8 @@ async def image_generations( log.debug(f"images: {images}") return images elif ( - request.app.state.config.ENGINE == "automatic1111" - or request.app.state.config.ENGINE == "" + request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" + or request.app.state.config.IMAGE_GENERATION_ENGINE == "" ): if form_data.model: set_image_model(form_data.model) From 8c38708827ea990d9c8323784213daa60bd83042 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 12 Dec 2024 20:26:28 -0800 Subject: [PATCH 22/26] wip --- backend/open_webui/routers/images.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 03e115cc0..3f51fbdb4 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -195,7 +195,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)): def set_image_model(request: Request, model: str): log.info(f"Setting image model to {model}") - request.app.state.config.MODEL = model + request.app.state.config.IMAGE_GENERATION_MODEL = model if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]: api_auth = get_automatic1111_api_auth() r = requests.get( @@ -210,18 +210,22 @@ def set_image_model(request: Request, model: str): json=options, headers={"authorization": api_auth}, ) - return request.app.state.config.MODEL + return request.app.state.config.IMAGE_GENERATION_MODEL -def get_image_model(): +def get_image_model(request): if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": return ( - request.app.state.config.MODEL - if request.app.state.config.MODEL + request.app.state.config.IMAGE_GENERATION_MODEL + if request.app.state.config.IMAGE_GENERATION_MODEL else "dall-e-2" ) elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": - return request.app.state.config.MODEL if request.app.state.config.MODEL else "" + return ( + request.app.state.config.IMAGE_GENERATION_MODEL + if request.app.state.config.IMAGE_GENERATION_MODEL + else "" + ) elif ( request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" or request.app.state.config.IMAGE_GENERATION_ENGINE == "" @@ -247,7 +251,7 @@ class ImageConfigForm(BaseModel): @router.get("/image/config") async def get_image_config(request: Request, user=Depends(get_admin_user)): return { - "MODEL": request.app.state.config.MODEL, + "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL, "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE, "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS, } @@ -278,7 +282,7 @@ async def update_image_config( ) return { - "MODEL": request.app.state.config.MODEL, + "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL, "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE, "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS, } @@ -450,8 +454,8 @@ async def image_generations( data = { "model": ( - request.app.state.config.MODEL - if request.app.state.config.MODEL != "" + request.app.state.config.IMAGE_GENERATION_MODEL + if request.app.state.config.IMAGE_GENERATION_MODEL != "" else "dall-e-2" ), "prompt": form_data.prompt, @@ -513,7 +517,7 @@ async def image_generations( } ) res = await comfyui_generate_image( - request.app.state.config.MODEL, + request.app.state.config.IMAGE_GENERATION_MODEL, form_data, user.id, request.app.state.config.COMFYUI_BASE_URL, From 1197c640c43a6b72ef3df6ff5470b805781d6e04 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 12 Dec 2024 22:28:42 -0800 Subject: [PATCH 23/26] refac --- backend/open_webui/config.py | 46 ++ backend/open_webui/main.py | 698 ++----------------------- backend/open_webui/routers/tasks.py | 172 ++---- backend/open_webui/utils/chat.py | 42 +- backend/open_webui/utils/middleware.py | 507 ++++++++++++++++++ backend/open_webui/utils/models.py | 24 + 6 files changed, 664 insertions(+), 825 deletions(-) create mode 100644 backend/open_webui/utils/middleware.py diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 955b3423e..018e21b30 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -957,12 +957,45 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""), ) +DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. + +Examples of titles: +📉 Stock Market Trends +🍪 Perfect Chocolate Chip Recipe +Evolution of Music Streaming +Remote Work Productivity Tips +Artificial Intelligence in Healthcare +🎮 Video Game Development Insights + + +{{MESSAGES:END:2}} +""" + + TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig( "TAGS_GENERATION_PROMPT_TEMPLATE", "task.tags.prompt_template", os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""), ) +DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE = """### Task: +Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags. + +### Guidelines: +- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education) +- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation +- If content is too short (less than 3 messages) or too diverse, use only ["General"] +- Use the chat's primary language; default to English if multilingual +- Prioritize accuracy over specificity + +### Output: +JSON format: { "tags": ["tag1", "tag2", "tag3"] } + +### Chat History: + +{{MESSAGES:END:6}} +""" + ENABLE_TAGS_GENERATION = PersistentConfig( "ENABLE_TAGS_GENERATION", "task.tags.enable", @@ -1081,6 +1114,19 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( ) +DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text.""" + + +DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). + +Message: ```{{prompt}}```""" + +DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" + +Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. + +Responses from models: {{responses}}""" + #################################### # Vector Database #################################### diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index ca0522289..ce9d47959 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -41,8 +41,6 @@ from starlette.responses import Response, StreamingResponse from open_webui.socket.main import ( app as socket_app, periodic_usage_pool_cleanup, - get_event_call, - get_event_emitter, ) from open_webui.routers import ( audio, @@ -74,12 +72,6 @@ from open_webui.routers.retrieval import ( get_ef, get_rf, ) -from open_webui.routers.pipelines import ( - process_pipeline_inlet_filter, -) - -from open_webui.retrieval.utils import get_sources_from_files - from open_webui.internal.db import Session @@ -87,8 +79,6 @@ from open_webui.models.functions import Functions from open_webui.models.models import Models from open_webui.models.users import UserModel, Users - -from open_webui.constants import TASKS from open_webui.config import ( # Ollama ENABLE_OLLAMA_API, @@ -274,43 +264,22 @@ from open_webui.env import ( ) -from open_webui.utils.models import get_all_models, get_all_base_models +from open_webui.utils.models import ( + get_all_models, + get_all_base_models, + check_model_access, +) from open_webui.utils.chat import ( generate_chat_completion as chat_completion_handler, chat_completed as chat_completed_handler, chat_action as chat_action_handler, ) - - -from open_webui.utils.plugin import load_function_module_by_id -from open_webui.utils.misc import ( - add_or_update_system_message, - get_last_user_message, - prepend_to_first_user_message_content, - openai_chat_chunk_message_template, - openai_chat_completion_message_template, -) - - -from open_webui.utils.payload import convert_payload_openai_to_ollama -from open_webui.utils.response import ( - convert_response_ollama_to_openai, - convert_streaming_response_ollama_to_openai, -) - -from open_webui.utils.task import ( - get_task_model_id, - rag_template, - tools_function_calling_generation_template, -) -from open_webui.utils.tools import get_tools +from open_webui.utils.middleware import process_chat_payload, process_chat_response from open_webui.utils.access_control import has_access from open_webui.utils.auth import ( decode_token, get_admin_user, - get_current_user, - get_http_authorization_cred, get_verified_user, ) from open_webui.utils.oauth import oauth_manager @@ -665,634 +634,6 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( app.state.MODELS = {} -################################## -# -# ChatCompletion Middleware -# -################################## - - -async def chat_completion_filter_functions_handler(body, model, extra_params): - skip_files = None - - def get_filter_function_ids(model): - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [ - function.id for function in Functions.get_global_filter_functions() - ] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - filter_ids.sort(key=get_priority) - return filter_ids - - filter_ids = get_filter_function_ids(model) - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in app.state.FUNCTIONS: - function_module = app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - app.state.FUNCTIONS[filter_id] = function_module - - # Check if the function has a file_handler variable - if hasattr(function_module, "file_handler"): - skip_files = function_module.file_handler - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if not hasattr(function_module, "inlet"): - continue - - try: - inlet = function_module.inlet - - # Get the signature of the function - sig = inspect.signature(inlet) - params = {"body": body} | { - k: v - for k, v in { - **extra_params, - "__model__": model, - "__id__": filter_id, - }.items() - if k in sig.parameters - } - - if "__user__" in params and hasattr(function_module, "UserValves"): - try: - params["__user__"]["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, params["__user__"]["id"] - ) - ) - except Exception as e: - print(e) - - if inspect.iscoroutinefunction(inlet): - body = await inlet(**params) - else: - body = inlet(**params) - - except Exception as e: - print(f"Error: {e}") - raise e - - if skip_files and "files" in body.get("metadata", {}): - del body["metadata"]["files"] - - return body, {} - - -async def chat_completion_tools_handler( - request: Request, body: dict, user: UserModel, models, extra_params: dict -) -> tuple[dict, dict]: - async def get_content_from_response(response) -> Optional[str]: - content = None - if hasattr(response, "body_iterator"): - async for chunk in response.body_iterator: - data = json.loads(chunk.decode("utf-8")) - content = data["choices"][0]["message"]["content"] - - # Cleanup any remaining background tasks if necessary - if response.background is not None: - await response.background() - else: - content = response["choices"][0]["message"]["content"] - return content - - def get_tools_function_calling_payload(messages, task_model_id, content): - user_message = get_last_user_message(messages) - history = "\n".join( - f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" - for message in messages[::-1][:4] - ) - - prompt = f"History:\n{history}\nQuery: {user_message}" - - return { - "model": task_model_id, - "messages": [ - {"role": "system", "content": content}, - {"role": "user", "content": f"Query: {prompt}"}, - ], - "stream": False, - "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, - } - - # If tool_ids field is present, call the functions - metadata = body.get("metadata", {}) - - tool_ids = metadata.get("tool_ids", None) - log.debug(f"{tool_ids=}") - if not tool_ids: - return body, {} - - skip_files = False - sources = [] - - task_model_id = get_task_model_id( - body["model"], - request.app.state.config.TASK_MODEL, - request.app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - tools = get_tools( - request, - tool_ids, - user, - { - **extra_params, - "__model__": models[task_model_id], - "__messages__": body["messages"], - "__files__": metadata.get("files", []), - }, - ) - log.info(f"{tools=}") - - specs = [tool["spec"] for tool in tools.values()] - tools_specs = json.dumps(specs) - - if app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "": - template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - else: - template = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text.""" - - tools_function_calling_prompt = tools_function_calling_generation_template( - template, tools_specs - ) - log.info(f"{tools_function_calling_prompt=}") - payload = get_tools_function_calling_payload( - body["messages"], task_model_id, tools_function_calling_prompt - ) - - try: - payload = process_pipeline_inlet_filter(request, payload, user, models) - except Exception as e: - raise e - - try: - response = await generate_chat_completions(form_data=payload, user=user) - log.debug(f"{response=}") - content = await get_content_from_response(response) - log.debug(f"{content=}") - - if not content: - return body, {} - - try: - content = content[content.find("{") : content.rfind("}") + 1] - if not content: - raise Exception("No JSON object found in the response") - - result = json.loads(content) - - tool_function_name = result.get("name", None) - if tool_function_name not in tools: - return body, {} - - tool_function_params = result.get("parameters", {}) - - try: - required_params = ( - tools[tool_function_name] - .get("spec", {}) - .get("parameters", {}) - .get("required", []) - ) - tool_function = tools[tool_function_name]["callable"] - tool_function_params = { - k: v - for k, v in tool_function_params.items() - if k in required_params - } - tool_output = await tool_function(**tool_function_params) - - except Exception as e: - tool_output = str(e) - - if isinstance(tool_output, str): - if tools[tool_function_name]["citation"]: - sources.append( - { - "source": { - "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" - }, - "document": [tool_output], - "metadata": [ - { - "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" - } - ], - } - ) - else: - sources.append( - { - "source": {}, - "document": [tool_output], - "metadata": [ - { - "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" - } - ], - } - ) - - if tools[tool_function_name]["file_handler"]: - skip_files = True - - except Exception as e: - log.exception(f"Error: {e}") - content = None - except Exception as e: - log.exception(f"Error: {e}") - content = None - - log.debug(f"tool_contexts: {sources}") - - if skip_files and "files" in body.get("metadata", {}): - del body["metadata"]["files"] - - return body, {"sources": sources} - - -async def chat_completion_files_handler( - request: Request, body: dict, user: UserModel -) -> tuple[dict, dict[str, list]]: - sources = [] - - if files := body.get("metadata", {}).get("files", None): - try: - queries_response = await generate_queries( - { - "model": body["model"], - "messages": body["messages"], - "type": "retrieval", - }, - user, - ) - queries_response = queries_response["choices"][0]["message"]["content"] - - try: - bracket_start = queries_response.find("{") - bracket_end = queries_response.rfind("}") + 1 - - if bracket_start == -1 or bracket_end == -1: - raise Exception("No JSON object found in the response") - - queries_response = queries_response[bracket_start:bracket_end] - queries_response = json.loads(queries_response) - except Exception as e: - queries_response = {"queries": [queries_response]} - - queries = queries_response.get("queries", []) - except Exception as e: - queries = [] - - if len(queries) == 0: - queries = [get_last_user_message(body["messages"])] - - sources = get_sources_from_files( - files=files, - queries=queries, - embedding_function=request.app.state.EMBEDDING_FUNCTION, - k=request.app.state.config.TOP_K, - reranking_function=request.app.state.rf, - r=request.app.state.config.RELEVANCE_THRESHOLD, - hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, - ) - - log.debug(f"rag_contexts:sources: {sources}") - return body, {"sources": sources} - - -class ChatCompletionMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - if not ( - request.method == "POST" - and any( - endpoint in request.url.path - for endpoint in ["/ollama/api/chat", "/chat/completions"] - ) - ): - return await call_next(request) - log.debug(f"request.url.path: {request.url.path}") - - await get_all_models(request) - models = app.state.MODELS - - async def get_body_and_model_and_user(request, models): - # Read the original request body - body = await request.body() - body_str = body.decode("utf-8") - body = json.loads(body_str) if body_str else {} - - model_id = body["model"] - if model_id not in models: - raise Exception("Model not found") - model = models[model_id] - - user = get_current_user( - request, - get_http_authorization_cred(request.headers.get("Authorization")), - ) - - return body, model, user - - try: - body, model, user = await get_body_and_model_and_user(request, models) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - model_info = Models.get_model_by_id(model["id"]) - if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: - if model.get("arena"): - if not has_access( - user.id, - type="read", - access_control=model.get("info", {}) - .get("meta", {}) - .get("access_control", {}), - ): - raise HTTPException( - status_code=403, - detail="Model not found", - ) - else: - if not model_info: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content={"detail": "Model not found"}, - ) - elif not ( - user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control - ) - ): - return JSONResponse( - status_code=status.HTTP_403_FORBIDDEN, - content={"detail": "User does not have access to the model"}, - ) - - metadata = { - "chat_id": body.pop("chat_id", None), - "message_id": body.pop("id", None), - "session_id": body.pop("session_id", None), - "tool_ids": body.get("tool_ids", None), - "files": body.get("files", None), - } - body["metadata"] = metadata - - extra_params = { - "__event_emitter__": get_event_emitter(metadata), - "__event_call__": get_event_call(metadata), - "__user__": { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - "__metadata__": metadata, - } - - # Initialize data_items to store additional data to be sent to the client - # Initialize contexts and citation - data_items = [] - sources = [] - - try: - body, flags = await chat_completion_filter_functions_handler( - body, model, extra_params - ) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - tool_ids = body.pop("tool_ids", None) - files = body.pop("files", None) - - metadata = { - **metadata, - "tool_ids": tool_ids, - "files": files, - } - body["metadata"] = metadata - - try: - body, flags = await chat_completion_tools_handler( - request, body, user, models, extra_params - ) - sources.extend(flags.get("sources", [])) - except Exception as e: - log.exception(e) - - try: - body, flags = await chat_completion_files_handler(request, body, user) - sources.extend(flags.get("sources", [])) - except Exception as e: - log.exception(e) - - # If context is not empty, insert it into the messages - if len(sources) > 0: - context_string = "" - for source_idx, source in enumerate(sources): - source_id = source.get("source", {}).get("name", "") - - if "document" in source: - for doc_idx, doc_context in enumerate(source["document"]): - metadata = source.get("metadata") - doc_source_id = None - - if metadata: - doc_source_id = metadata[doc_idx].get("source", source_id) - - if source_id: - context_string += f"{doc_source_id if doc_source_id is not None else source_id}{doc_context}\n" - else: - # If there is no source_id, then do not include the source_id tag - context_string += f"{doc_context}\n" - - context_string = context_string.strip() - prompt = get_last_user_message(body["messages"]) - - if prompt is None: - raise Exception("No user message found") - if ( - app.state.config.RELEVANCE_THRESHOLD == 0 - and context_string.strip() == "" - ): - log.debug( - f"With a 0 relevancy threshold for RAG, the context cannot be empty" - ) - - # Workaround for Ollama 2.0+ system prompt issue - # TODO: replace with add_or_update_system_message - if model["owned_by"] == "ollama": - body["messages"] = prepend_to_first_user_message_content( - rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt), - body["messages"], - ) - else: - body["messages"] = add_or_update_system_message( - rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt), - body["messages"], - ) - - # If there are citations, add them to the data_items - sources = [ - source for source in sources if source.get("source", {}).get("name", "") - ] - if len(sources) > 0: - data_items.append({"sources": sources}) - - modified_body_bytes = json.dumps(body).encode("utf-8") - # Replace the request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], - ] - - response = await call_next(request) - if not isinstance(response, StreamingResponse): - return response - - content_type = response.headers["Content-Type"] - is_openai = "text/event-stream" in content_type - is_ollama = "application/x-ndjson" in content_type - if not is_openai and not is_ollama: - return response - - def wrap_item(item): - return f"data: {item}\n\n" if is_openai else f"{item}\n" - - async def stream_wrapper(original_generator, data_items): - for item in data_items: - yield wrap_item(json.dumps(item)) - - async for data in original_generator: - yield data - - return StreamingResponse( - stream_wrapper(response.body_iterator, data_items), - headers=dict(response.headers), - ) - - async def _receive(self, body: bytes): - return {"type": "http.request", "body": body, "more_body": False} - - -app.add_middleware(ChatCompletionMiddleware) - - -class PipelineMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - if not ( - request.method == "POST" - and any( - endpoint in request.url.path - for endpoint in ["/ollama/api/chat", "/chat/completions"] - ) - ): - return await call_next(request) - - log.debug(f"request.url.path: {request.url.path}") - - # Read the original request body - body = await request.body() - # Decode body to string - body_str = body.decode("utf-8") - # Parse string to JSON - data = json.loads(body_str) if body_str else {} - - try: - user = get_current_user( - request, - get_http_authorization_cred(request.headers["Authorization"]), - ) - except KeyError as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={"detail": "Not authenticated"}, - ) - except HTTPException as e: - return JSONResponse( - status_code=e.status_code, - content={"detail": e.detail}, - ) - - await get_all_models(request) - models = app.state.MODELS - - try: - data = process_pipeline_inlet_filter(request, data, user, models) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - modified_body_bytes = json.dumps(data).encode("utf-8") - # Replace the request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], - ] - - response = await call_next(request) - return response - - async def _receive(self, body: bytes): - return {"type": "http.request", "body": body, "more_body": False} - - -app.add_middleware(PipelineMiddleware) - class RedirectMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): @@ -1471,8 +812,32 @@ async def chat_completion( user=Depends(get_verified_user), bypass_filter: bool = False, ): + try: - return await chat_completion_handler(request, form_data, user, bypass_filter) + model_id = form_data.get("model", None) + if model_id not in request.app.state.MODELS: + raise Exception("Model not found") + model = request.app.state.MODELS[model_id] + + # Check if user has access to the model + if not bypass_filter and user.role == "user": + try: + check_model_access(user, model) + except Exception as e: + raise e + + form_data, events = await process_chat_payload(request, form_data, user, model) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + + try: + response = await chat_completion_handler( + request, form_data, user, bypass_filter + ) + return await process_chat_response(response, events) except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -1480,6 +845,7 @@ async def chat_completion( ) +# Alias for chat_completion (Legacy) generate_chat_completions = chat_completion generate_chat_completion = chat_completion diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index 425a3ef02..a2a6cdc92 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -21,8 +21,12 @@ from open_webui.routers.pipelines import process_pipeline_inlet_filter from open_webui.utils.task import get_task_model_id from open_webui.config import ( + DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, + DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, + DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE, + DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE, ) from open_webui.env import SRC_LOG_LEVELS @@ -150,19 +154,7 @@ async def generate_title( if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE else: - template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. - -Examples of titles: -📉 Stock Market Trends -🍪 Perfect Chocolate Chip Recipe -Evolution of Music Streaming -Remote Work Productivity Tips -Artificial Intelligence in Healthcare -🎮 Video Game Development Insights - - -{{MESSAGES:END:2}} -""" + template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE content = title_generation_template( template, @@ -191,24 +183,13 @@ Artificial Intelligence in Healthcare }, } - # Handle pipeline filters try: - payload = process_pipeline_inlet_filter(request, payload, user, models) + return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completion(request, form_data=payload, user=user) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) @router.post("/tags/completions") @@ -247,23 +228,7 @@ async def generate_chat_tags( if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE else: - template = """### Task: -Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags. - -### Guidelines: -- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education) -- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation -- If content is too short (less than 3 messages) or too diverse, use only ["General"] -- Use the chat's primary language; default to English if multilingual -- Prioritize accuracy over specificity - -### Output: -JSON format: { "tags": ["tag1", "tag2", "tag3"] } - -### Chat History: - -{{MESSAGES:END:6}} -""" + template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE content = tags_generation_template( template, form_data["messages"], {"name": user.name} @@ -280,24 +245,13 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } }, } - # Handle pipeline filters try: - payload = process_pipeline_inlet_filter(request, payload, user, models) + return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completion(request, form_data=payload, user=user) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) @router.post("/queries/completions") @@ -361,24 +315,13 @@ async def generate_queries( }, } - # Handle pipeline filters try: - payload = process_pipeline_inlet_filter(request, payload, user, models) + return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completion(request, form_data=payload, user=user) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) @router.post("/auto/completions") @@ -447,24 +390,13 @@ async def generate_autocompletion( }, } - # Handle pipeline filters try: - payload = process_pipeline_inlet_filter(request, payload, user, models) + return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completion(request, form_data=payload, user=user) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) @router.post("/emoji/completions") @@ -492,11 +424,8 @@ async def generate_emoji( log.debug(f"generating emoji using model {task_model_id} for user {user.email} ") - template = ''' -Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). + template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE -Message: """{{prompt}}""" -''' content = emoji_generation_template( template, form_data["prompt"], @@ -521,24 +450,13 @@ Message: """{{prompt}}""" "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data}, } - # Handle pipeline filters try: - payload = process_pipeline_inlet_filter(request, payload, user, models) + return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completion(request, form_data=payload, user=user) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) @router.post("/moa/completions") @@ -566,11 +484,7 @@ async def generate_moa_response( log.debug(f"generating MOA model {task_model_id} for user {user.email} ") - template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" - -Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. - -Responses from models: {{responses}}""" + template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE content = moa_response_generation_template( template, @@ -590,19 +504,9 @@ Responses from models: {{responses}}""" } try: - payload = process_pipeline_inlet_filter(request, payload, user, models) + return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completion(request, form_data=payload, user=user) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index f127e6eb4..96d7693b5 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -3,7 +3,7 @@ import logging import sys from aiocache import cached -from typing import Any +from typing import Any, Optional import random import json import inspect @@ -11,11 +11,13 @@ import inspect from fastapi import Request from starlette.responses import Response, StreamingResponse + +from open_webui.models.users import UserModel + from open_webui.socket.main import ( get_event_call, get_event_emitter, ) - from open_webui.functions import generate_function_chat_completion from open_webui.routers.openai import ( @@ -27,22 +29,22 @@ from open_webui.routers.ollama import ( ) from open_webui.routers.pipelines import ( + process_pipeline_inlet_filter, process_pipeline_outlet_filter, ) - from open_webui.models.functions import Functions from open_webui.models.models import Models from open_webui.utils.plugin import load_function_module_by_id -from open_webui.utils.access_control import has_access -from open_webui.utils.models import get_all_models +from open_webui.utils.models import get_all_models, check_model_access from open_webui.utils.payload import convert_payload_openai_to_ollama from open_webui.utils.response import ( convert_response_ollama_to_openai, convert_streaming_response_ollama_to_openai, ) + from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL @@ -66,30 +68,20 @@ async def generate_chat_completion( if model_id not in models: raise Exception("Model not found") + # Process the form_data through the pipeline + try: + form_data = process_pipeline_inlet_filter(request, form_data, user, models) + except Exception as e: + raise e + model = models[model_id] # Check if user has access to the model if not bypass_filter and user.role == "user": - if model.get("arena"): - if not has_access( - user.id, - type="read", - access_control=model.get("info", {}) - .get("meta", {}) - .get("access_control", {}), - ): - raise Exception("Model not found") - else: - model_info = Models.get_model_by_id(model_id) - if not model_info: - raise Exception("Model not found") - elif not ( - user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control - ) - ): - raise Exception("Model not found") + try: + check_model_access(user, model) + except Exception as e: + raise e if model["owned_by"] == "arena": model_ids = model.get("info", {}).get("meta", {}).get("model_ids") diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py new file mode 100644 index 000000000..9c5186e00 --- /dev/null +++ b/backend/open_webui/utils/middleware.py @@ -0,0 +1,507 @@ +import time +import logging +import sys + +from aiocache import cached +from typing import Any, Optional +import random +import json +import inspect + +from fastapi import Request +from starlette.responses import Response, StreamingResponse + + +from open_webui.socket.main import ( + get_event_call, + get_event_emitter, +) +from open_webui.routers.tasks import generate_queries + + +from open_webui.models.users import UserModel +from open_webui.models.functions import Functions +from open_webui.models.models import Models + +from open_webui.retrieval.utils import get_sources_from_files + + +from open_webui.utils.chat import generate_chat_completion +from open_webui.utils.task import ( + get_task_model_id, + rag_template, + tools_function_calling_generation_template, +) +from open_webui.utils.misc import ( + add_or_update_system_message, + get_last_user_message, + prepend_to_first_user_message_content, +) +from open_webui.utils.tools import get_tools +from open_webui.utils.plugin import load_function_module_by_id + + +from open_webui.config import DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL +from open_webui.constants import TASKS + + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +async def chat_completion_filter_functions_handler(request, body, model, extra_params): + skip_files = None + + def get_filter_function_ids(model): + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [ + function.id for function in Functions.get_global_filter_functions() + ] + if "info" in model and "meta" in model["info"]: + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) + + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + + filter_ids = [ + filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids + ] + + filter_ids.sort(key=get_priority) + return filter_ids + + filter_ids = get_filter_function_ids(model) + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if not filter: + continue + + if filter_id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + request.app.state.FUNCTIONS[filter_id] = function_module + + # Check if the function has a file_handler variable + if hasattr(function_module, "file_handler"): + skip_files = function_module.file_handler + + # Apply valves to the function + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) + + if hasattr(function_module, "inlet"): + try: + inlet = function_module.inlet + + # Create a dictionary of parameters to be passed to the function + params = {"body": body} | { + k: v + for k, v in { + **extra_params, + "__model__": model, + "__id__": filter_id, + }.items() + if k in inspect.signature(inlet).parameters + } + + if "__user__" in params and hasattr(function_module, "UserValves"): + try: + params["__user__"]["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, params["__user__"]["id"] + ) + ) + except Exception as e: + print(e) + + if inspect.iscoroutinefunction(inlet): + body = await inlet(**params) + else: + body = inlet(**params) + + except Exception as e: + print(f"Error: {e}") + raise e + + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] + + return body, {} + + +async def chat_completion_tools_handler( + request: Request, body: dict, user: UserModel, models, extra_params: dict +) -> tuple[dict, dict]: + async def get_content_from_response(response) -> Optional[str]: + content = None + if hasattr(response, "body_iterator"): + async for chunk in response.body_iterator: + data = json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() + else: + content = response["choices"][0]["message"]["content"] + return content + + def get_tools_function_calling_payload(messages, task_model_id, content): + user_message = get_last_user_message(messages) + history = "\n".join( + f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" + for message in messages[::-1][:4] + ) + + prompt = f"History:\n{history}\nQuery: {user_message}" + + return { + "model": task_model_id, + "messages": [ + {"role": "system", "content": content}, + {"role": "user", "content": f"Query: {prompt}"}, + ], + "stream": False, + "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, + } + + # If tool_ids field is present, call the functions + metadata = body.get("metadata", {}) + + tool_ids = metadata.get("tool_ids", None) + log.debug(f"{tool_ids=}") + if not tool_ids: + return body, {} + + skip_files = False + sources = [] + + task_model_id = get_task_model_id( + body["model"], + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + tools = get_tools( + request, + tool_ids, + user, + { + **extra_params, + "__model__": models[task_model_id], + "__messages__": body["messages"], + "__files__": metadata.get("files", []), + }, + ) + log.info(f"{tools=}") + + specs = [tool["spec"] for tool in tools.values()] + tools_specs = json.dumps(specs) + + if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "": + template = request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + else: + template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + + tools_function_calling_prompt = tools_function_calling_generation_template( + template, tools_specs + ) + log.info(f"{tools_function_calling_prompt=}") + payload = get_tools_function_calling_payload( + body["messages"], task_model_id, tools_function_calling_prompt + ) + + try: + response = await generate_chat_completion(request, form_data=payload, user=user) + log.debug(f"{response=}") + content = await get_content_from_response(response) + log.debug(f"{content=}") + + if not content: + return body, {} + + try: + content = content[content.find("{") : content.rfind("}") + 1] + if not content: + raise Exception("No JSON object found in the response") + + result = json.loads(content) + + tool_function_name = result.get("name", None) + if tool_function_name not in tools: + return body, {} + + tool_function_params = result.get("parameters", {}) + + try: + required_params = ( + tools[tool_function_name] + .get("spec", {}) + .get("parameters", {}) + .get("required", []) + ) + tool_function = tools[tool_function_name]["callable"] + tool_function_params = { + k: v + for k, v in tool_function_params.items() + if k in required_params + } + tool_output = await tool_function(**tool_function_params) + + except Exception as e: + tool_output = str(e) + + if isinstance(tool_output, str): + if tools[tool_function_name]["citation"]: + sources.append( + { + "source": { + "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + }, + "document": [tool_output], + "metadata": [ + { + "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + } + ], + } + ) + else: + sources.append( + { + "source": {}, + "document": [tool_output], + "metadata": [ + { + "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + } + ], + } + ) + + if tools[tool_function_name]["file_handler"]: + skip_files = True + + except Exception as e: + log.exception(f"Error: {e}") + content = None + except Exception as e: + log.exception(f"Error: {e}") + content = None + + log.debug(f"tool_contexts: {sources}") + + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] + + return body, {"sources": sources} + + +async def chat_completion_files_handler( + request: Request, body: dict, user: UserModel +) -> tuple[dict, dict[str, list]]: + sources = [] + + if files := body.get("metadata", {}).get("files", None): + try: + queries_response = await generate_queries( + { + "model": body["model"], + "messages": body["messages"], + "type": "retrieval", + }, + user, + ) + queries_response = queries_response["choices"][0]["message"]["content"] + + try: + bracket_start = queries_response.find("{") + bracket_end = queries_response.rfind("}") + 1 + + if bracket_start == -1 or bracket_end == -1: + raise Exception("No JSON object found in the response") + + queries_response = queries_response[bracket_start:bracket_end] + queries_response = json.loads(queries_response) + except Exception as e: + queries_response = {"queries": [queries_response]} + + queries = queries_response.get("queries", []) + except Exception as e: + queries = [] + + if len(queries) == 0: + queries = [get_last_user_message(body["messages"])] + + sources = get_sources_from_files( + files=files, + queries=queries, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=request.app.state.config.TOP_K, + reranking_function=request.app.state.rf, + r=request.app.state.config.RELEVANCE_THRESHOLD, + hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, + ) + + log.debug(f"rag_contexts:sources: {sources}") + return body, {"sources": sources} + + +async def process_chat_payload(request, form_data, user, model): + metadata = { + "chat_id": form_data.pop("chat_id", None), + "message_id": form_data.pop("id", None), + "session_id": form_data.pop("session_id", None), + "tool_ids": form_data.get("tool_ids", None), + "files": form_data.get("files", None), + } + form_data["metadata"] = metadata + + extra_params = { + "__event_emitter__": get_event_emitter(metadata), + "__event_call__": get_event_call(metadata), + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + "__metadata__": metadata, + } + + # Initialize events to store additional event to be sent to the client + # Initialize contexts and citation + models = request.app.state.MODELS + events = [] + sources = [] + + try: + form_data, flags = await chat_completion_filter_functions_handler( + request, form_data, model, extra_params + ) + except Exception as e: + return Exception(f"Error: {e}") + + tool_ids = form_data.pop("tool_ids", None) + files = form_data.pop("files", None) + + metadata = { + **metadata, + "tool_ids": tool_ids, + "files": files, + } + form_data["metadata"] = metadata + + try: + form_data, flags = await chat_completion_tools_handler( + request, form_data, user, models, extra_params + ) + sources.extend(flags.get("sources", [])) + except Exception as e: + log.exception(e) + + try: + form_data, flags = await chat_completion_files_handler(request, form_data, user) + sources.extend(flags.get("sources", [])) + except Exception as e: + log.exception(e) + + # If context is not empty, insert it into the messages + if len(sources) > 0: + context_string = "" + for source_idx, source in enumerate(sources): + source_id = source.get("source", {}).get("name", "") + + if "document" in source: + for doc_idx, doc_context in enumerate(source["document"]): + metadata = source.get("metadata") + doc_source_id = None + + if metadata: + doc_source_id = metadata[doc_idx].get("source", source_id) + + if source_id: + context_string += f"{doc_source_id if doc_source_id is not None else source_id}{doc_context}\n" + else: + # If there is no source_id, then do not include the source_id tag + context_string += f"{doc_context}\n" + + context_string = context_string.strip() + prompt = get_last_user_message(form_data["messages"]) + + if prompt is None: + raise Exception("No user message found") + if ( + request.app.state.config.RELEVANCE_THRESHOLD == 0 + and context_string.strip() == "" + ): + log.debug( + f"With a 0 relevancy threshold for RAG, the context cannot be empty" + ) + + # Workaround for Ollama 2.0+ system prompt issue + # TODO: replace with add_or_update_system_message + if model["owned_by"] == "ollama": + form_data["messages"] = prepend_to_first_user_message_content( + rag_template( + request.app.state.config.RAG_TEMPLATE, context_string, prompt + ), + form_data["messages"], + ) + else: + form_data["messages"] = add_or_update_system_message( + rag_template( + request.app.state.config.RAG_TEMPLATE, context_string, prompt + ), + form_data["messages"], + ) + + # If there are citations, add them to the data_items + sources = [source for source in sources if source.get("source", {}).get("name", "")] + + if len(sources) > 0: + events.append({"sources": sources}) + + return form_data, events + + +async def process_chat_response(response, events): + if not isinstance(response, StreamingResponse): + return response + + content_type = response.headers["Content-Type"] + is_openai = "text/event-stream" in content_type + is_ollama = "application/x-ndjson" in content_type + + if not is_openai and not is_ollama: + return response + + async def stream_wrapper(original_generator, events): + def wrap_item(item): + return f"data: {item}\n\n" if is_openai else f"{item}\n" + + for event in events: + yield wrap_item(json.dumps(event)) + + async for data in original_generator: + yield data + + return StreamingResponse( + stream_wrapper(response.body_iterator, events), + headers=dict(response.headers), + ) diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 7fdd8b605..b135f8997 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -14,6 +14,7 @@ from open_webui.models.models import Models from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.access_control import has_access from open_webui.config import ( @@ -220,3 +221,26 @@ async def get_all_models(request): request.app.state.MODELS = {model["id"]: model for model in models} return models + + +def check_model_access(user, model): + if model.get("arena"): + if not has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + raise Exception("Model not found") + else: + model_info = Models.get_model_by_id(model.get("id")) + if not model_info: + raise Exception("Model not found") + elif not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise Exception("Model not found") From 9a081c8593533fccce50600378a68f433019fc7a Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 12 Dec 2024 22:32:28 -0800 Subject: [PATCH 24/26] refac --- backend/open_webui/main.py | 2 ++ backend/open_webui/routers/files.py | 14 +++++++++----- backend/open_webui/routers/knowledge.py | 10 ++++++++-- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index ce9d47959..31604984f 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -812,6 +812,8 @@ async def chat_completion( user=Depends(get_verified_user), bypass_filter: bool = False, ): + if not request.app.state.MODELS: + await get_all_models(request) try: model_id = form_data.get("model", None) diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index bc553ab26..e56eef273 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -21,7 +21,7 @@ from open_webui.env import SRC_LOG_LEVELS from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request from fastapi.responses import FileResponse, StreamingResponse @@ -39,7 +39,9 @@ router = APIRouter() @router.post("/", response_model=FileModelResponse) -def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): +def upload_file( + request: Request, file: UploadFile = File(...), user=Depends(get_verified_user) +): log.info(f"file.content_type: {file.content_type}") try: unsanitized_filename = file.filename @@ -68,7 +70,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): ) try: - process_file(ProcessFileForm(file_id=id)) + process_file(request, ProcessFileForm(file_id=id)) file_item = Files.get_file_by_id(id=id) except Exception as e: log.exception(e) @@ -183,13 +185,15 @@ class ContentForm(BaseModel): @router.post("/{id}/data/content/update") async def update_file_data_content_by_id( - id: str, form_data: ContentForm, user=Depends(get_verified_user) + request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user) ): file = Files.get_file_by_id(id) if file and (file.user_id == user.id or user.role == "admin"): try: - process_file(ProcessFileForm(file_id=id, content=form_data.content)) + process_file( + request, ProcessFileForm(file_id=id, content=form_data.content) + ) file = Files.get_file_by_id(id=id) except Exception as e: log.exception(e) diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 0f4dd9283..7f9947d7a 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -242,6 +242,7 @@ class KnowledgeFileIdForm(BaseModel): @router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse]) def add_file_to_knowledge_by_id( + request: Request, id: str, form_data: KnowledgeFileIdForm, user=Depends(get_verified_user), @@ -274,7 +275,9 @@ def add_file_to_knowledge_by_id( # Add content to the vector database try: - process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id)) + process_file( + request, ProcessFileForm(file_id=form_data.file_id, collection_name=id) + ) except Exception as e: log.debug(e) raise HTTPException( @@ -318,6 +321,7 @@ def add_file_to_knowledge_by_id( @router.post("/{id}/file/update", response_model=Optional[KnowledgeFilesResponse]) def update_file_from_knowledge_by_id( + request: Request, id: str, form_data: KnowledgeFileIdForm, user=Depends(get_verified_user), @@ -349,7 +353,9 @@ def update_file_from_knowledge_by_id( # Add content to the vector database try: - process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id)) + process_file( + request, ProcessFileForm(file_id=form_data.file_id, collection_name=id) + ) except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, From f9a05dd1e10716b55905a0e9598dbf25890406aa Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 12 Dec 2024 23:31:08 -0800 Subject: [PATCH 25/26] refac --- backend/open_webui/utils/chat.py | 2 +- backend/open_webui/utils/misc.py | 10 +- backend/open_webui/utils/response.py | 57 +++++++- src/lib/apis/streaming/index.ts | 22 ++- src/lib/components/chat/Chat.svelte | 134 +++++++++--------- .../chat/Messages/ResponseMessage.svelte | 109 +++++--------- 6 files changed, 190 insertions(+), 144 deletions(-) diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 96d7693b5..676a4a203 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -136,7 +136,7 @@ async def generate_chat_completion( response = await generate_ollama_chat_completion( request=request, form_data=form_data, user=user, bypass_filter=bypass_filter ) - if form_data.stream: + if form_data.get("stream"): response.headers["content-type"] = "text/event-stream" return StreamingResponse( convert_streaming_response_ollama_to_openai(response), diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index a5af492ba..aba696f60 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -106,7 +106,7 @@ def openai_chat_message_template(model: str): def openai_chat_chunk_message_template( - model: str, message: Optional[str] = None + model: str, message: Optional[str] = None, usage: Optional[dict] = None ) -> dict: template = openai_chat_message_template(model) template["object"] = "chat.completion.chunk" @@ -114,17 +114,23 @@ def openai_chat_chunk_message_template( template["choices"][0]["delta"] = {"content": message} else: template["choices"][0]["finish_reason"] = "stop" + + if usage: + template["usage"] = usage return template def openai_chat_completion_message_template( - model: str, message: Optional[str] = None + model: str, message: Optional[str] = None, usage: Optional[dict] = None ) -> dict: template = openai_chat_message_template(model) template["object"] = "chat.completion" if message is not None: template["choices"][0]["message"] = {"content": message, "role": "assistant"} template["choices"][0]["finish_reason"] = "stop" + + if usage: + template["usage"] = usage return template diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py index b8501e92c..891016e43 100644 --- a/backend/open_webui/utils/response.py +++ b/backend/open_webui/utils/response.py @@ -21,8 +21,63 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response) message_content = data.get("message", {}).get("content", "") done = data.get("done", False) + usage = None + if done: + usage = { + "response_token/s": ( + round( + ( + ( + data.get("eval_count", 0) + / ((data.get("eval_duration", 0) / 1_000_000_000)) + ) + * 100 + ), + 2, + ) + if data.get("eval_duration", 0) > 0 + else "N/A" + ), + "prompt_token/s": ( + round( + ( + ( + data.get("prompt_eval_count", 0) + / ( + ( + data.get("prompt_eval_duration", 0) + / 1_000_000_000 + ) + ) + ) + * 100 + ), + 2, + ) + if data.get("prompt_eval_duration", 0) > 0 + else "N/A" + ), + "total_duration": round( + ((data.get("total_duration", 0) / 1_000_000) * 100), 2 + ), + "load_duration": round( + ((data.get("load_duration", 0) / 1_000_000) * 100), 2 + ), + "prompt_eval_count": data.get("prompt_eval_count", 0), + "prompt_eval_duration": round( + ((data.get("prompt_eval_duration", 0) / 1_000_000) * 100), 2 + ), + "eval_count": data.get("eval_count", 0), + "eval_duration": round( + ((data.get("eval_duration", 0) / 1_000_000) * 100), 2 + ), + "approximate_total": ( + lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s" + )((data.get("total_duration", 0) or 0) // 1_000_000_000), + } + data = openai_chat_chunk_message_template( - model, message_content if not done else None + model, message_content if not done else None, usage ) line = f"data: {json.dumps(data)}\n\n" diff --git a/src/lib/apis/streaming/index.ts b/src/lib/apis/streaming/index.ts index 54804385d..5617ce36c 100644 --- a/src/lib/apis/streaming/index.ts +++ b/src/lib/apis/streaming/index.ts @@ -77,10 +77,14 @@ async function* openAIStreamToIterator( continue; } + if (parsedData.usage) { + yield { done: false, value: '', usage: parsedData.usage }; + continue; + } + yield { done: false, value: parsedData.choices?.[0]?.delta?.content ?? '', - usage: parsedData.usage }; } catch (e) { console.error('Error extracting delta from SSE event:', e); @@ -98,10 +102,26 @@ async function* streamLargeDeltasAsRandomChunks( yield textStreamUpdate; return; } + + if (textStreamUpdate.error) { + yield textStreamUpdate; + continue; + } if (textStreamUpdate.sources) { yield textStreamUpdate; continue; } + if (textStreamUpdate.selectedModelId) { + yield textStreamUpdate; + continue; + } + if (textStreamUpdate.usage) { + yield textStreamUpdate; + continue; + } + + + let content = textStreamUpdate.value; if (content.length < 5) { yield { done: false, value: content }; diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index e6a653420..a55cbc87b 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -455,41 +455,43 @@ ////////////////////////// const initNewChat = async () => { - if (sessionStorage.selectedModels) { - selectedModels = JSON.parse(sessionStorage.selectedModels); - sessionStorage.removeItem('selectedModels'); - } else { - if ($page.url.searchParams.get('models')) { - selectedModels = $page.url.searchParams.get('models')?.split(','); - } else if ($page.url.searchParams.get('model')) { - const urlModels = $page.url.searchParams.get('model')?.split(','); + if ($page.url.searchParams.get('models')) { + selectedModels = $page.url.searchParams.get('models')?.split(','); + } else if ($page.url.searchParams.get('model')) { + const urlModels = $page.url.searchParams.get('model')?.split(','); - if (urlModels.length === 1) { - const m = $models.find((m) => m.id === urlModels[0]); - if (!m) { - const modelSelectorButton = document.getElementById('model-selector-0-button'); - if (modelSelectorButton) { - modelSelectorButton.click(); - await tick(); + if (urlModels.length === 1) { + const m = $models.find((m) => m.id === urlModels[0]); + if (!m) { + const modelSelectorButton = document.getElementById('model-selector-0-button'); + if (modelSelectorButton) { + modelSelectorButton.click(); + await tick(); - const modelSelectorInput = document.getElementById('model-search-input'); - if (modelSelectorInput) { - modelSelectorInput.focus(); - modelSelectorInput.value = urlModels[0]; - modelSelectorInput.dispatchEvent(new Event('input')); - } + const modelSelectorInput = document.getElementById('model-search-input'); + if (modelSelectorInput) { + modelSelectorInput.focus(); + modelSelectorInput.value = urlModels[0]; + modelSelectorInput.dispatchEvent(new Event('input')); } - } else { - selectedModels = urlModels; } } else { selectedModels = urlModels; } - } else if ($settings?.models) { - selectedModels = $settings?.models; - } else if ($config?.default_models) { - console.log($config?.default_models.split(',') ?? ''); - selectedModels = $config?.default_models.split(','); + } else { + selectedModels = urlModels; + } + } else { + if (sessionStorage.selectedModels) { + selectedModels = JSON.parse(sessionStorage.selectedModels); + sessionStorage.removeItem('selectedModels'); + } else { + if ($settings?.models) { + selectedModels = $settings?.models; + } else if ($config?.default_models) { + console.log($config?.default_models.split(',') ?? ''); + selectedModels = $config?.default_models.split(','); + } } } @@ -1056,11 +1058,14 @@ } let _response = null; - if (model?.owned_by === 'ollama') { - _response = await sendPromptOllama(model, prompt, responseMessageId, _chatId); - } else if (model) { - _response = await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); - } + + // if (model?.owned_by === 'ollama') { + // _response = await sendPromptOllama(model, prompt, responseMessageId, _chatId); + // } else if (model) { + // } + + _response = await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); + _responses.push(_response); if (chatEventEmitter) clearInterval(chatEventEmitter); @@ -1207,24 +1212,14 @@ $settings?.params?.stream_response ?? params?.stream_response ?? true; + const [res, controller] = await generateChatCompletion(localStorage.token, { stream: stream, model: model.id, messages: messagesBody, - options: { - ...{ ...($settings?.params ?? {}), ...params }, - stop: - (params?.stop ?? $settings?.params?.stop ?? undefined) - ? (params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop).map( - (str) => decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"')) - ) - : undefined, - num_predict: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined, - repeat_penalty: - params?.frequency_penalty ?? $settings?.params?.frequency_penalty ?? undefined - }, format: $settings.requestFormat ?? undefined, keep_alive: $settings.keepAlive ?? undefined, + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, session_id: $socket?.id, @@ -1542,13 +1537,6 @@ { stream: stream, model: model.id, - ...(stream && (model.info?.meta?.capabilities?.usage ?? false) - ? { - stream_options: { - include_usage: true - } - } - : {}), messages: [ params?.system || $settings.system || (responseMessage?.userContext ?? null) ? { @@ -1593,23 +1581,36 @@ content: message?.merged?.content ?? message.content }) })), - seed: params?.seed ?? $settings?.params?.seed ?? undefined, - stop: - (params?.stop ?? $settings?.params?.stop ?? undefined) - ? (params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop).map( - (str) => decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"')) - ) - : undefined, - temperature: params?.temperature ?? $settings?.params?.temperature ?? undefined, - top_p: params?.top_p ?? $settings?.params?.top_p ?? undefined, - frequency_penalty: - params?.frequency_penalty ?? $settings?.params?.frequency_penalty ?? undefined, - max_tokens: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined, + + // params: { + // ...$settings?.params, + // ...params, + + // format: $settings.requestFormat ?? undefined, + // keep_alive: $settings.keepAlive ?? undefined, + // stop: + // (params?.stop ?? $settings?.params?.stop ?? undefined) + // ? ( + // params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop + // ).map((str) => + // decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"')) + // ) + // : undefined + // }, + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, session_id: $socket?.id, chat_id: $chatId, - id: responseMessageId + id: responseMessageId, + + ...(stream && (model.info?.meta?.capabilities?.usage ?? false) + ? { + stream_options: { + include_usage: true + } + } + : {}) }, `${WEBUI_BASE_URL}/api` ); @@ -1636,6 +1637,7 @@ await handleOpenAIError(error, null, model, responseMessage); break; } + if (done || stopResponseFlag || _chatId !== $chatId) { responseMessage.done = true; history.messages[responseMessageId] = responseMessage; @@ -1648,7 +1650,7 @@ } if (usage) { - responseMessage.info = { ...usage, openai: true, usage }; + responseMessage.usage = usage; } if (selectedModelId) { diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index 2e883df93..76210f68c 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -932,82 +932,45 @@ {/if} - {#if message.info} + {#if message.usage} ${sanitizeResponseContent( - JSON.stringify(message.info.usage, null, 2) - .replace(/"([^(")"]+)":/g, '$1:') - .slice(1, -1) - .split('\n') - .map((line) => line.slice(2)) - .map((line) => (line.endsWith(',') ? line.slice(0, -1) : line)) - .join('\n') - )}` - : `prompt_tokens: ${message.info.prompt_tokens ?? 'N/A'}
- completion_tokens: ${message.info.completion_tokens ?? 'N/A'}
- total_tokens: ${message.info.total_tokens ?? 'N/A'}` - : `response_token/s: ${ - `${ - Math.round( - ((message.info.eval_count ?? 0) / - ((message.info.eval_duration ?? 0) / 1000000000)) * - 100 - ) / 100 - } tokens` ?? 'N/A' - }
- prompt_token/s: ${ - Math.round( - ((message.info.prompt_eval_count ?? 0) / - ((message.info.prompt_eval_duration ?? 0) / 1000000000)) * - 100 - ) / 100 ?? 'N/A' - } tokens
- total_duration: ${ - Math.round(((message.info.total_duration ?? 0) / 1000000) * 100) / 100 ?? 'N/A' - }ms
- load_duration: ${ - Math.round(((message.info.load_duration ?? 0) / 1000000) * 100) / 100 ?? 'N/A' - }ms
- prompt_eval_count: ${message.info.prompt_eval_count ?? 'N/A'}
- prompt_eval_duration: ${ - Math.round(((message.info.prompt_eval_duration ?? 0) / 1000000) * 100) / 100 ?? - 'N/A' - }ms
- eval_count: ${message.info.eval_count ?? 'N/A'}
- eval_duration: ${ - Math.round(((message.info.eval_duration ?? 0) / 1000000) * 100) / 100 ?? 'N/A' - }ms
- approximate_total: ${approximateToHumanReadable(message.info.total_duration ?? 0)}`} - placement="top" + content={message.usage + ? `
${sanitizeResponseContent(
+													JSON.stringify(message.usage, null, 2)
+														.replace(/"([^(")"]+)":/g, '$1:')
+														.slice(1, -1)
+														.split('\n')
+														.map((line) => line.slice(2))
+														.map((line) => (line.endsWith(',') ? line.slice(0, -1) : line))
+														.join('\n')
+												)}
` + : ''} + placement="bottom" > - - - + + +
{/if} From 6efca03a8f1613a1494b63fa16378f7f1c68ef85 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Fri, 13 Dec 2024 22:51:43 -0800 Subject: [PATCH 26/26] refac --- backend/open_webui/functions.py | 1 + backend/open_webui/socket/main.py | 5 +++++ backend/open_webui/utils/chat.py | 2 ++ backend/open_webui/utils/middleware.py | 1 + 4 files changed, 9 insertions(+) diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index d424d4663..9c241432a 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -230,6 +230,7 @@ async def generate_function_chat_completion( "role": user.role, }, "__metadata__": metadata, + "__request__": request, } extra_params["__tools__"] = get_tools( request, diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index ba5eeb6ae..c0f45c9a0 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -173,6 +173,11 @@ async def user_count(sid): await sio.emit("user-count", {"count": len(USER_POOL.items())}) +@sio.on("chat") +async def chat(sid, data): + print("chat", sid, SESSION_POOL[sid], data) + + @sio.event async def disconnect(sid): if sid in SESSION_POOL: diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 676a4a203..56904d1d8 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -237,6 +237,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any): "__id__": filter_id, "__event_emitter__": __event_emitter__, "__event_call__": __event_call__, + "__request__": request, } # Add extra params in contained in function signature @@ -334,6 +335,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A "__id__": sub_action_id if sub_action_id is not None else action_id, "__event_emitter__": __event_emitter__, "__event_call__": __event_call__, + "__request__": request, } # Add extra params in contained in function signature diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 9c5186e00..1d2bc2b99 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -382,6 +382,7 @@ async def process_chat_payload(request, form_data, user, model): "role": user.role, }, "__metadata__": metadata, + "__request__": request, } # Initialize events to store additional event to be sent to the client