diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index c62fb3ac9..d3e323a4a 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -683,6 +683,17 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) CACHE_DIR = f"{DATA_DIR}/cache" Path(CACHE_DIR).mkdir(parents=True, exist_ok=True) + +#################################### +# DIRECT CONNECTIONS +#################################### + +ENABLE_DIRECT_CONNECTIONS = PersistentConfig( + "ENABLE_DIRECT_CONNECTIONS", + "direct.enable", + os.environ.get("ENABLE_DIRECT_CONNECTIONS", "True").lower() == "true", +) + #################################### # OLLAMA_BASE_URL #################################### @@ -1326,6 +1337,54 @@ Your task is to synthesize these responses into a single, high-quality response. Responses from models: {{responses}}""" +#################################### +# Code Interpreter +#################################### + +ENABLE_CODE_INTERPRETER = PersistentConfig( + "ENABLE_CODE_INTERPRETER", + "code_interpreter.enable", + os.environ.get("ENABLE_CODE_INTERPRETER", "True").lower() == "true", +) + +CODE_INTERPRETER_ENGINE = PersistentConfig( + "CODE_INTERPRETER_ENGINE", + "code_interpreter.engine", + os.environ.get("CODE_INTERPRETER_ENGINE", "pyodide"), +) + +CODE_INTERPRETER_PROMPT_TEMPLATE = PersistentConfig( + "CODE_INTERPRETER_PROMPT_TEMPLATE", + "code_interpreter.prompt_template", + os.environ.get("CODE_INTERPRETER_PROMPT_TEMPLATE", ""), +) + +CODE_INTERPRETER_JUPYTER_URL = PersistentConfig( + "CODE_INTERPRETER_JUPYTER_URL", + "code_interpreter.jupyter.url", + os.environ.get("CODE_INTERPRETER_JUPYTER_URL", ""), +) + +CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig( + "CODE_INTERPRETER_JUPYTER_AUTH", + "code_interpreter.jupyter.auth", + os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH", ""), +) + +CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig( + "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", + "code_interpreter.jupyter.auth_token", + os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", ""), +) + + +CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig( + "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", + "code_interpreter.jupyter.auth_password", + os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", ""), +) + + DEFAULT_CODE_INTERPRETER_PROMPT = """ #### Tools Available @@ -1336,9 +1395,8 @@ DEFAULT_CODE_INTERPRETER_PROMPT = """ - When coding, **always aim to print meaningful outputs** (e.g., results, tables, summaries, or visuals) to better interpret and verify the findings. Avoid relying on implicit outputs; prioritize explicit and clear print statements so the results are effectively communicated to the user. - After obtaining the printed output, **always provide a concise analysis, interpretation, or next steps to help the user understand the findings or refine the outcome further.** - If the results are unclear, unexpected, or require validation, refine the code and execute it again as needed. Always aim to deliver meaningful insights from the results, iterating if necessary. - - If a link is provided for an image, audio, or any file, include it in the response exactly as given to ensure the user has access to the original resource. + - **If a link to an image, audio, or any file is provided in markdown format in the output, ALWAYS regurgitate word for word, explicitly display it as part of the response to ensure the user can access it easily, do NOT change the link.** - All responses should be communicated in the chat's primary language, ensuring seamless understanding. 
If the chat is multilingual, default to English for clarity. - - **If a link to an image, audio, or any file is provided in markdown format, ALWAYS regurgitate explicitly display it as part of the response to ensure the user can access it easily, do NOT change the link.** Ensure that the tools are effectively utilized to achieve the highest-quality analysis for the user.""" @@ -1691,6 +1749,12 @@ MOJEEK_SEARCH_API_KEY = PersistentConfig( os.getenv("MOJEEK_SEARCH_API_KEY", ""), ) +BOCHA_SEARCH_API_KEY = PersistentConfig( + "BOCHA_SEARCH_API_KEY", + "rag.web.search.bocha_search_api_key", + os.getenv("BOCHA_SEARCH_API_KEY", ""), +) + SERPSTACK_API_KEY = PersistentConfig( "SERPSTACK_API_KEY", "rag.web.search.serpstack_api_key", diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index dfd5a254b..d6e2832f1 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -97,6 +97,16 @@ from open_webui.config import ( OPENAI_API_BASE_URLS, OPENAI_API_KEYS, OPENAI_API_CONFIGS, + # Direct Connections + ENABLE_DIRECT_CONNECTIONS, + # Code Interpreter + ENABLE_CODE_INTERPRETER, + CODE_INTERPRETER_ENGINE, + CODE_INTERPRETER_PROMPT_TEMPLATE, + CODE_INTERPRETER_JUPYTER_URL, + CODE_INTERPRETER_JUPYTER_AUTH, + CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, + CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, # Image AUTOMATIC1111_API_AUTH, AUTOMATIC1111_BASE_URL, @@ -183,6 +193,7 @@ from open_webui.config import ( EXA_API_KEY, KAGI_SEARCH_API_KEY, MOJEEK_SEARCH_API_KEY, + BOCHA_SEARCH_API_KEY, GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, GOOGLE_DRIVE_CLIENT_ID, @@ -325,7 +336,11 @@ class SPAStaticFiles(StaticFiles): return await super().get_response(path, scope) except (HTTPException, StarletteHTTPException) as ex: if ex.status_code == 404: - return await super().get_response("index.html", scope) + if path.endswith(".js"): + # Return 404 for javascript files + raise ex + else: + return await super().get_response("index.html", scope) else: raise ex @@ -392,6 +407,14 @@ app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS app.state.OPENAI_MODELS = {} +######################################## +# +# DIRECT CONNECTIONS +# +######################################## + +app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS + ######################################## # # WEBUI @@ -517,6 +540,7 @@ app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY +app.state.config.BOCHA_SEARCH_API_KEY = BOCHA_SEARCH_API_KEY app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS app.state.config.SERPER_API_KEY = SERPER_API_KEY @@ -574,6 +598,24 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) +######################################## +# +# CODE INTERPRETER +# +######################################## + +app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER +app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE +app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMPLATE + +app.state.config.CODE_INTERPRETER_JUPYTER_URL = CODE_INTERPRETER_JUPYTER_URL +app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = CODE_INTERPRETER_JUPYTER_AUTH +app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = ( + CODE_INTERPRETER_JUPYTER_AUTH_TOKEN +) 
+app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = ( + CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD +) ######################################## # @@ -759,6 +801,7 @@ app.include_router(openai.router, prefix="/openai", tags=["openai"]) app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"]) app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"]) app.include_router(images.router, prefix="/api/v1/images", tags=["images"]) + app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"]) app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"]) @@ -1017,15 +1060,17 @@ async def get_app_config(request: Request): "enable_websocket": ENABLE_WEBSOCKET_SUPPORT, **( { + "enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS, "enable_channels": app.state.config.ENABLE_CHANNELS, "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH, - "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, + "enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER, "enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION, + "enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING, "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING, - "enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, "enable_admin_export": ENABLE_ADMIN_EXPORT, "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, + "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, } if user is not None else {} diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 5c196281f..605299528 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -271,6 +271,24 @@ class UsersTable: except Exception: return None + def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]: + try: + with get_db() as db: + user_settings = db.query(User).filter_by(id=id).first().settings + + if user_settings is None: + user_settings = {} + + user_settings.update(updated) + + db.query(User).filter_by(id=id).update({"settings": user_settings}) + db.commit() + + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except Exception: + return None + def delete_user_by_id(self, id: str) -> bool: try: # Remove User from Groups diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py index 41d634391..b8186b3f9 100644 --- a/backend/open_webui/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -120,18 +120,12 @@ class OpenSearchClient: return None query_body = { - "query": { - "bool": { - "filter": [] - } - }, + "query": {"bool": {"filter": []}}, "_source": ["text", "metadata"], } for field, value in filter.items(): - query_body["query"]["bool"]["filter"].append({ - "term": {field: value} - }) + query_body["query"]["bool"]["filter"].append({"term": {field: value}}) size = limit if limit else 10 @@ -139,7 +133,7 @@ class OpenSearchClient: result = self.client.search( index=f"{self.index_prefix}_{collection_name}", body=query_body, - size=size + size=size, ) return self._result_to_get_result(result) diff --git a/backend/open_webui/retrieval/web/bocha.py b/backend/open_webui/retrieval/web/bocha.py new file mode 100644 index 
000000000..98bdae704 --- /dev/null +++ b/backend/open_webui/retrieval/web/bocha.py @@ -0,0 +1,72 @@ +import logging +from typing import Optional + +import requests +import json +from open_webui.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + +def _parse_response(response): + result = {} + if "data" in response: + data = response["data"] + if "webPages" in data: + webPages = data["webPages"] + if "value" in webPages: + result["webpage"] = [ + { + "id": item.get("id", ""), + "name": item.get("name", ""), + "url": item.get("url", ""), + "snippet": item.get("snippet", ""), + "summary": item.get("summary", ""), + "siteName": item.get("siteName", ""), + "siteIcon": item.get("siteIcon", ""), + "datePublished": item.get("datePublished", "") or item.get("dateLastCrawled", ""), + } + for item in webPages["value"] + ] + return result + + +def search_bocha( + api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None +) -> list[SearchResult]: + """Search using Bocha's Search API and return the results as a list of SearchResult objects. + + Args: + api_key (str): A Bocha Search API key + query (str): The query to search for + """ + url = "https://api.bochaai.com/v1/web-search?utm_source=ollama" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + payload = json.dumps({ + "query": query, + "summary": True, + "freshness": "noLimit", + "count": count + }) + + response = requests.post(url, headers=headers, data=payload, timeout=5) + response.raise_for_status() + results = _parse_response(response.json()) + print(results) + if filter_list: + results = get_filtered_results(results, filter_list) + + return [ + SearchResult( + link=result["url"], + title=result.get("name"), + snippet=result.get("summary") + ) + for result in results.get("webpage", [])[:count] + ] + diff --git a/backend/open_webui/retrieval/web/google_pse.py b/backend/open_webui/retrieval/web/google_pse.py index 2c51dd3c9..ea8225262 100644 --- a/backend/open_webui/retrieval/web/google_pse.py +++ b/backend/open_webui/retrieval/web/google_pse.py @@ -8,7 +8,6 @@ from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) - def search_google_pse( api_key: str, search_engine_id: str, @@ -17,34 +16,51 @@ def search_google_pse( filter_list: Optional[list[str]] = None, ) -> list[SearchResult]: """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects. + Handles pagination for counts greater than 10. Args: api_key (str): A Programmable Search Engine API key search_engine_id (str): A Programmable Search Engine ID query (str): The query to search for + count (int): The number of results to return (max 100, as PSE max results per query is 10 and max page is 10) + filter_list (Optional[list[str]], optional): A list of keywords to filter out from results. Defaults to None. + + Returns: + list[SearchResult]: A list of SearchResult objects. 
""" url = "https://www.googleapis.com/customsearch/v1" - headers = {"Content-Type": "application/json"} - params = { - "cx": search_engine_id, - "q": query, - "key": api_key, - "num": count, - } + all_results = [] + start_index = 1 # Google PSE start parameter is 1-based - response = requests.request("GET", url, headers=headers, params=params) - response.raise_for_status() + while count > 0: + num_results_this_page = min(count, 10) # Google PSE max results per page is 10 + params = { + "cx": search_engine_id, + "q": query, + "key": api_key, + "num": num_results_this_page, + "start": start_index, + } + response = requests.request("GET", url, headers=headers, params=params) + response.raise_for_status() + json_response = response.json() + results = json_response.get("items", []) + if results: # check if results are returned. If not, no more pages to fetch. + all_results.extend(results) + count -= len(results) # Decrement count by the number of results fetched in this page. + start_index += 10 # Increment start index for the next page + else: + break # No more results from Google PSE, break the loop - json_response = response.json() - results = json_response.get("items", []) if filter_list: - results = get_filtered_results(results, filter_list) + all_results = get_filtered_results(all_results, filter_list) + return [ SearchResult( link=result["link"], title=result.get("title"), snippet=result.get("snippet"), ) - for result in results + for result in all_results ] diff --git a/backend/open_webui/retrieval/web/jina_search.py b/backend/open_webui/retrieval/web/jina_search.py index 02af42ea6..a87293db5 100644 --- a/backend/open_webui/retrieval/web/jina_search.py +++ b/backend/open_webui/retrieval/web/jina_search.py @@ -25,13 +25,10 @@ def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]: "Accept": "application/json", "Content-Type": "application/json", "Authorization": api_key, - "X-Retain-Images": "none" + "X-Retain-Images": "none", } - payload = { - "q": query, - "count": count if count <= 10 else 10 - } + payload = {"q": query, "count": count if count <= 10 else 10} url = str(URL(jina_search_endpoint)) response = requests.post(url, headers=headers, json=payload) diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index 7242042e2..e2d05ba90 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -560,10 +560,14 @@ def transcribe(request: Request, file_path): # Extract transcript from Deepgram response try: - transcript = response_data["results"]["channels"][0]["alternatives"][0].get("transcript", "") + transcript = response_data["results"]["channels"][0]["alternatives"][ + 0 + ].get("transcript", "") except (KeyError, IndexError) as e: log.error(f"Malformed response from Deepgram: {str(e)}") - raise Exception("Failed to parse Deepgram response - unexpected response format") + raise Exception( + "Failed to parse Deepgram response - unexpected response format" + ) data = {"text": transcript.strip()} # Save transcript diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index ef6c4d8c1..016075234 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -36,6 +36,98 @@ async def export_config(user=Depends(get_admin_user)): return get_config() +############################ +# Direct Connections Config +############################ + + +class DirectConnectionsConfigForm(BaseModel): + ENABLE_DIRECT_CONNECTIONS: bool + + 
+@router.get("/direct_connections", response_model=DirectConnectionsConfigForm) +async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)): + return { + "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS, + } + + +@router.post("/direct_connections", response_model=DirectConnectionsConfigForm) +async def set_direct_connections_config( + request: Request, + form_data: DirectConnectionsConfigForm, + user=Depends(get_admin_user), +): + request.app.state.config.ENABLE_DIRECT_CONNECTIONS = ( + form_data.ENABLE_DIRECT_CONNECTIONS + ) + return { + "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS, + } + + +############################ +# CodeInterpreterConfig +############################ +class CodeInterpreterConfigForm(BaseModel): + ENABLE_CODE_INTERPRETER: bool + CODE_INTERPRETER_ENGINE: str + CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str] + CODE_INTERPRETER_JUPYTER_URL: Optional[str] + CODE_INTERPRETER_JUPYTER_AUTH: Optional[str] + CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str] + CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str] + + +@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm) +async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)): + return { + "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER, + "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE, + "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE, + "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, + "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, + "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, + "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, + } + + +@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm) +async def set_code_interpreter_config( + request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER + request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE + request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = ( + form_data.CODE_INTERPRETER_PROMPT_TEMPLATE + ) + + request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = ( + form_data.CODE_INTERPRETER_JUPYTER_URL + ) + + request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = ( + form_data.CODE_INTERPRETER_JUPYTER_AUTH + ) + + request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = ( + form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN + ) + request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = ( + form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD + ) + + return { + "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER, + "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE, + "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE, + "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, + "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, + "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, + "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": 
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, + } + + ############################ # SetDefaultModels ############################ diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 7160c2e86..051321257 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -3,30 +3,22 @@ import os import uuid from pathlib import Path from typing import Optional -from pydantic import BaseModel -import mimetypes from urllib.parse import quote -from open_webui.storage.provider import Storage - +from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status +from fastapi.responses import FileResponse, StreamingResponse +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import SRC_LOG_LEVELS from open_webui.models.files import ( FileForm, FileModel, FileModelResponse, Files, ) -from open_webui.routers.retrieval import process_file, ProcessFileForm - -from open_webui.config import UPLOAD_DIR -from open_webui.env import SRC_LOG_LEVELS -from open_webui.constants import ERROR_MESSAGES - - -from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request -from fastapi.responses import FileResponse, StreamingResponse - - +from open_webui.routers.retrieval import ProcessFileForm, process_file +from open_webui.storage.provider import Storage from open_webui.utils.auth import get_admin_user, get_verified_user +from pydantic import BaseModel log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -41,7 +33,10 @@ router = APIRouter() @router.post("/", response_model=FileModelResponse) def upload_file( - request: Request, file: UploadFile = File(...), user=Depends(get_verified_user) + request: Request, + file: UploadFile = File(...), + user=Depends(get_verified_user), + file_metadata: dict = {}, ): log.info(f"file.content_type: {file.content_type}") try: @@ -65,6 +60,7 @@ def upload_file( "name": name, "content_type": file.content_type, "size": len(contents), + "data": file_metadata, }, } ), @@ -126,7 +122,7 @@ async def delete_all_files(user=Depends(get_admin_user)): Storage.delete_all_files() except Exception as e: log.exception(e) - log.error(f"Error deleting files") + log.error("Error deleting files") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT("Error deleting files"), @@ -248,7 +244,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): ) except Exception as e: log.exception(e) - log.error(f"Error getting file content") + log.error("Error getting file content") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT("Error getting file content"), @@ -279,7 +275,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): ) except Exception as e: log.exception(e) - log.error(f"Error getting file content") + log.error("Error getting file content") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT("Error getting file content"), @@ -355,7 +351,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)): Storage.delete_file(file.path) except Exception as e: log.exception(e) - log.error(f"Error deleting files") + log.error("Error deleting files") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT("Error deleting files"), diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py 
index 7afd9d106..4046773de 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -1,32 +1,26 @@ import asyncio import base64 +import io import json import logging import mimetypes import re -import uuid from pathlib import Path from typing import Optional import requests - - -from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel - - +from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS - +from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS +from open_webui.routers.files import upload_file from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.images.comfyui import ( ComfyUIGenerateImageForm, ComfyUIWorkflow, comfyui_generate_image, ) - +from pydantic import BaseModel log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["IMAGES"]) @@ -271,7 +265,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)): async def update_image_config( request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user) ): - set_image_model(request, form_data.MODEL) pattern = r"^\d+x\d+$" @@ -383,40 +376,22 @@ class GenerateImageForm(BaseModel): negative_prompt: Optional[str] = None -def save_b64_image(b64_str): +def load_b64_image_data(b64_str): try: - image_id = str(uuid.uuid4()) - if "," in b64_str: header, encoded = b64_str.split(",", 1) mime_type = header.split(";")[0] - img_data = base64.b64decode(encoded) - image_format = mimetypes.guess_extension(mime_type) - - image_filename = f"{image_id}{image_format}" - file_path = IMAGE_CACHE_DIR / f"{image_filename}" - with open(file_path, "wb") as f: - f.write(img_data) - return image_filename else: - image_filename = f"{image_id}.png" - file_path = IMAGE_CACHE_DIR.joinpath(image_filename) - + mime_type = "image/png" img_data = base64.b64decode(b64_str) - - # Write the image data to a file - with open(file_path, "wb") as f: - f.write(img_data) - return image_filename - + return img_data, mime_type except Exception as e: - log.exception(f"Error saving image: {e}") + log.exception(f"Error loading image data: {e}") return None -def save_url_image(url, headers=None): - image_id = str(uuid.uuid4()) +def load_url_image_data(url, headers=None): try: if headers: r = requests.get(url, headers=headers) @@ -426,18 +401,7 @@ def save_url_image(url, headers=None): r.raise_for_status() if r.headers["content-type"].split("/")[0] == "image": mime_type = r.headers["content-type"] - image_format = mimetypes.guess_extension(mime_type) - - if not image_format: - raise ValueError("Could not determine image type from MIME type") - - image_filename = f"{image_id}{image_format}" - - file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}") - with open(file_path, "wb") as image_file: - for chunk in r.iter_content(chunk_size=8192): - image_file.write(chunk) - return image_filename + return r.content, mime_type else: log.error("Url does not point to an image.") return None @@ -447,6 +411,20 @@ def save_url_image(url, headers=None): return None +def upload_image(request, image_metadata, image_data, content_type, user): + image_format = mimetypes.guess_extension(content_type) + file = UploadFile( + file=io.BytesIO(image_data), + 
filename=f"generated-image{image_format}", # will be converted to a unique ID on upload_file + headers={ + "content-type": content_type, + }, + ) + file_item = upload_file(request, file, user, file_metadata=image_metadata) + url = request.app.url_path_for("get_file_content_by_id", id=file_item.id) + return url + + @router.post("/generations") async def image_generations( request: Request, @@ -500,13 +478,9 @@ async def image_generations( images = [] for image in res["data"]: - image_filename = save_b64_image(image["b64_json"]) - images.append({"url": f"/cache/image/generations/{image_filename}"}) - file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") - - with open(file_body_path, "w") as f: - json.dump(data, f) - + image_data, content_type = load_b64_image_data(image["b64_json"]) + url = upload_image(request, data, image_data, content_type, user) + images.append({"url": url}) return images elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": @@ -552,14 +526,15 @@ async def image_generations( "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}" } - image_filename = save_url_image(image["url"], headers) - images.append({"url": f"/cache/image/generations/{image_filename}"}) - file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") - - with open(file_body_path, "w") as f: - json.dump(form_data.model_dump(exclude_none=True), f) - - log.debug(f"images: {images}") + image_data, content_type = load_url_image_data(image["url"], headers) + url = upload_image( + request, + form_data.model_dump(exclude_none=True), + image_data, + content_type, + user, + ) + images.append({"url": url}) return images elif ( request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" @@ -604,13 +579,15 @@ async def image_generations( images = [] for image in res["images"]: - image_filename = save_b64_image(image) - images.append({"url": f"/cache/image/generations/{image_filename}"}) - file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") - - with open(file_body_path, "w") as f: - json.dump({**data, "info": res["info"]}, f) - + image_data, content_type = load_b64_image_data(image) + url = upload_image( + request, + {**data, "info": res["info"]}, + image_data, + content_type, + user, + ) + images.append({"url": url}) return images except Exception as e: error = e diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 1c6365683..64373c616 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -1424,11 +1424,11 @@ async def upload_model( os.makedirs(UPLOAD_DIR, exist_ok=True) # --- P1: save file locally --- - chunk_size = 1024 * 1024 * 2 # 2 MB chunks + chunk_size = 1024 * 1024 * 2 # 2 MB chunks with open(file_path, "wb") as out_f: while True: chunk = file.file.read(chunk_size) - #log.info(f"Chunk: {str(chunk)}") # DEBUG + # log.info(f"Chunk: {str(chunk)}") # DEBUG if not chunk: break out_f.write(chunk) @@ -1436,15 +1436,15 @@ async def upload_model( async def file_process_stream(): nonlocal ollama_url total_size = os.path.getsize(file_path) - log.info(f"Total Model Size: {str(total_size)}") # DEBUG + log.info(f"Total Model Size: {str(total_size)}") # DEBUG # --- P2: SSE progress + calculate sha256 hash --- file_hash = calculate_sha256(file_path, chunk_size) - log.info(f"Model Hash: {str(file_hash)}") # DEBUG + log.info(f"Model Hash: {str(file_hash)}") # DEBUG try: with open(file_path, "rb") as f: bytes_read = 0 - while chunk := f.read(chunk_size): + while chunk := 
f.read(chunk_size): bytes_read += len(chunk) progress = round(bytes_read / total_size * 100, 2) data_msg = { @@ -1460,25 +1460,23 @@ async def upload_model( response = requests.post(url, data=f) if response.ok: - log.info(f"Uploaded to /api/blobs") # DEBUG + log.info(f"Uploaded to /api/blobs") # DEBUG # Remove local file os.remove(file_path) # Create model in ollama model_name, ext = os.path.splitext(file.filename) - log.info(f"Created Model: {model_name}") # DEBUG + log.info(f"Created Model: {model_name}") # DEBUG create_payload = { "model": model_name, # Reference the file by its original name => the uploaded blob's digest - "files": { - file.filename: f"sha256:{file_hash}" - }, + "files": {file.filename: f"sha256:{file_hash}"}, } - log.info(f"Model Payload: {create_payload}") # DEBUG + log.info(f"Model Payload: {create_payload}") # DEBUG # Call ollama /api/create - #https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model + # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model create_resp = requests.post( url=f"{ollama_url}/api/create", headers={"Content-Type": "application/json"}, @@ -1486,7 +1484,7 @@ async def upload_model( ) if create_resp.ok: - log.info(f"API SUCCESS!") # DEBUG + log.info(f"API SUCCESS!") # DEBUG done_msg = { "done": True, "blob": f"sha256:{file_hash}", @@ -1506,4 +1504,4 @@ async def upload_model( res = {"error": str(e)} yield f"data: {json.dumps(res)}\n\n" - return StreamingResponse(file_process_stream(), media_type="text/event-stream") \ No newline at end of file + return StreamingResponse(file_process_stream(), media_type="text/event-stream") diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 1fb58ae3d..a624b8048 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -45,6 +45,7 @@ from open_webui.retrieval.web.utils import get_web_loader from open_webui.retrieval.web.brave import search_brave from open_webui.retrieval.web.kagi import search_kagi from open_webui.retrieval.web.mojeek import search_mojeek +from open_webui.retrieval.web.bocha import search_bocha from open_webui.retrieval.web.duckduckgo import search_duckduckgo from open_webui.retrieval.web.google_pse import search_google_pse from open_webui.retrieval.web.jina_search import search_jina @@ -379,6 +380,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY, "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY, "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY, + "bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY, "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY, "serpstack_https": request.app.state.config.SERPSTACK_HTTPS, "serper_api_key": request.app.state.config.SERPER_API_KEY, @@ -429,6 +431,7 @@ class WebSearchConfig(BaseModel): brave_search_api_key: Optional[str] = None kagi_search_api_key: Optional[str] = None mojeek_search_api_key: Optional[str] = None + bocha_search_api_key: Optional[str] = None serpstack_api_key: Optional[str] = None serpstack_https: Optional[bool] = None serper_api_key: Optional[str] = None @@ -525,6 +528,9 @@ async def update_rag_config( request.app.state.config.MOJEEK_SEARCH_API_KEY = ( form_data.web.search.mojeek_search_api_key ) + request.app.state.config.BOCHA_SEARCH_API_KEY = ( + form_data.web.search.bocha_search_api_key + ) request.app.state.config.SERPSTACK_API_KEY = ( 
form_data.web.search.serpstack_api_key ) @@ -591,6 +597,7 @@ async def update_rag_config( "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY, "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY, "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY, + "bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY, "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY, "serpstack_https": request.app.state.config.SERPSTACK_HTTPS, "serper_api_key": request.app.state.config.SERPER_API_KEY, @@ -1113,6 +1120,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: - BRAVE_SEARCH_API_KEY - KAGI_SEARCH_API_KEY - MOJEEK_SEARCH_API_KEY + - BOCHA_SEARCH_API_KEY - SERPSTACK_API_KEY - SERPER_API_KEY - SERPLY_API_KEY @@ -1180,6 +1188,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: ) else: raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables") + elif engine == "bocha": + if request.app.state.config.BOCHA_SEARCH_API_KEY: + return search_bocha( + request.app.state.config.BOCHA_SEARCH_API_KEY, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No BOCHA_SEARCH_API_KEY found in environment variables") elif engine == "serpstack": if request.app.state.config.SERPSTACK_API_KEY: return search_serpstack( diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index ddcaef767..872212d3c 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -153,7 +153,7 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)): async def update_user_settings_by_session_user( form_data: UserSettings, user=Depends(get_verified_user) ): - user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()}) + user = Users.update_user_settings_by_id(user.id, form_data.model_dump()) if user: return user.settings else: diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py index afc50b397..b03cf0a7e 100644 --- a/backend/open_webui/storage/provider.py +++ b/backend/open_webui/storage/provider.py @@ -94,7 +94,7 @@ class S3StorageProvider(StorageProvider): aws_secret_access_key=S3_SECRET_ACCESS_KEY, ) self.bucket_name = S3_BUCKET_NAME - self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else "" + self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else "" def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]: """Handles uploading of the file to S3 storage.""" @@ -108,7 +108,7 @@ class S3StorageProvider(StorageProvider): ) except ClientError as e: raise RuntimeError(f"Error uploading file to S3: {e}") - + def get_file(self, file_path: str) -> str: """Handles downloading of the file from S3 storage.""" try: @@ -137,7 +137,8 @@ class S3StorageProvider(StorageProvider): if "Contents" in response: for content in response["Contents"]: # Skip objects that were not uploaded from open-webui in the first place - if not content["Key"].startswith(self.key_prefix): continue + if not content["Key"].startswith(self.key_prefix): + continue self.s3_client.delete_object( Bucket=self.bucket_name, Key=content["Key"] @@ -150,11 +151,12 @@ class S3StorageProvider(StorageProvider): # The s3 key is the name assigned to an object. It excludes the bucket name, but includes the internal path and the file name. 
def _extract_s3_key(self, full_file_path: str) -> str: - return '/'.join(full_file_path.split("//")[1].split("/")[1:]) - + return "/".join(full_file_path.split("//")[1].split("/")[1:]) + def _get_local_file_path(self, s3_key: str) -> str: return f"{UPLOAD_DIR}/{s3_key.split('/')[-1]}" + class GCSStorageProvider(StorageProvider): def __init__(self): self.bucket_name = GCS_BUCKET_NAME diff --git a/backend/open_webui/utils/code_interpreter.py b/backend/open_webui/utils/code_interpreter.py new file mode 100644 index 000000000..0a74da9c7 --- /dev/null +++ b/backend/open_webui/utils/code_interpreter.py @@ -0,0 +1,148 @@ +import asyncio +import json +import uuid +import websockets +import requests +from urllib.parse import urljoin + + +async def execute_code_jupyter( + jupyter_url, code, token=None, password=None, timeout=10 +): + """ + Executes Python code in a Jupyter kernel. + Supports authentication with a token or password. + :param jupyter_url: Jupyter server URL (e.g., "http://localhost:8888") + :param code: Code to execute + :param token: Jupyter authentication token (optional) + :param password: Jupyter password (optional) + :param timeout: WebSocket timeout in seconds (default: 10s) + :return: Dictionary with stdout, stderr, and result + - Images are prefixed with "base64:image/png," and separated by newlines if multiple. + """ + session = requests.Session() # Maintain cookies + headers = {} # Headers for requests + + # Authenticate using password + if password and not token: + try: + login_url = urljoin(jupyter_url, "/login") + response = session.get(login_url) + response.raise_for_status() + xsrf_token = session.cookies.get("_xsrf") + if not xsrf_token: + raise ValueError("Failed to fetch _xsrf token") + + login_data = {"_xsrf": xsrf_token, "password": password} + login_response = session.post( + login_url, data=login_data, cookies=session.cookies + ) + login_response.raise_for_status() + headers["X-XSRFToken"] = xsrf_token + except Exception as e: + return { + "stdout": "", + "stderr": f"Authentication Error: {str(e)}", + "result": "", + } + + # Construct API URLs with authentication token if provided + params = f"?token={token}" if token else "" + kernel_url = urljoin(jupyter_url, f"/api/kernels{params}") + + try: + response = session.post(kernel_url, headers=headers, cookies=session.cookies) + response.raise_for_status() + kernel_id = response.json()["id"] + + websocket_url = urljoin( + jupyter_url.replace("http", "ws"), + f"/api/kernels/{kernel_id}/channels{params}", + ) + + ws_headers = {} + if password and not token: + ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf") + cookies = {name: value for name, value in session.cookies.items()} + ws_headers["Cookie"] = "; ".join( + [f"{name}={value}" for name, value in cookies.items()] + ) + + async with websockets.connect( + websocket_url, additional_headers=ws_headers + ) as ws: + msg_id = str(uuid.uuid4()) + execute_request = { + "header": { + "msg_id": msg_id, + "msg_type": "execute_request", + "username": "user", + "session": str(uuid.uuid4()), + "date": "", + "version": "5.3", + }, + "parent_header": {}, + "metadata": {}, + "content": { + "code": code, + "silent": False, + "store_history": True, + "user_expressions": {}, + "allow_stdin": False, + "stop_on_error": True, + }, + "channel": "shell", + } + await ws.send(json.dumps(execute_request)) + + stdout, stderr, result = "", "", [] + + while True: + try: + message = await asyncio.wait_for(ws.recv(), timeout) + message_data = json.loads(message) + if 
message_data.get("parent_header", {}).get("msg_id") == msg_id: + msg_type = message_data.get("msg_type") + + if msg_type == "stream": + if message_data["content"]["name"] == "stdout": + stdout += message_data["content"]["text"] + elif message_data["content"]["name"] == "stderr": + stderr += message_data["content"]["text"] + + elif msg_type in ("execute_result", "display_data"): + data = message_data["content"]["data"] + if "image/png" in data: + result.append( + f"data:image/png;base64,{data['image/png']}" + ) + elif "text/plain" in data: + result.append(data["text/plain"]) + + elif msg_type == "error": + stderr += "\n".join(message_data["content"]["traceback"]) + + elif ( + msg_type == "status" + and message_data["content"]["execution_state"] == "idle" + ): + break + + except asyncio.TimeoutError: + stderr += "\nExecution timed out." + break + + except Exception as e: + return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""} + + finally: + if kernel_id: + requests.delete( + f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies + ) + + return { + "stdout": stdout.strip(), + "stderr": stderr.strip(), + "result": "\n".join(result).strip() if result else "", + } diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 1d2768087..acd429118 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -72,7 +72,7 @@ from open_webui.utils.filter import ( get_sorted_filter_ids, process_filter_functions, ) - +from open_webui.utils.code_interpreter import execute_code_jupyter from open_webui.tasks import create_task @@ -684,7 +684,12 @@ async def process_chat_payload(request, form_data, metadata, user, model): if "code_interpreter" in features and features["code_interpreter"]: form_data["messages"] = add_or_update_user_message( - DEFAULT_CODE_INTERPRETER_PROMPT, form_data["messages"] + ( + request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE + if request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE != "" + else DEFAULT_CODE_INTERPRETER_PROMPT + ), + form_data["messages"], ) try: @@ -1639,21 +1644,60 @@ async def process_chat_response( content_blocks[-1]["type"] == "code_interpreter" and retries < MAX_RETRIES ): + await event_emitter( + { + "type": "chat:completion", + "data": { + "content": serialize_content_blocks(content_blocks), + }, + } + ) + retries += 1 log.debug(f"Attempt count: {retries}") output = "" try: if content_blocks[-1]["attributes"].get("type") == "code": - output = await event_caller( - { - "type": "execute:python", - "data": { - "id": str(uuid4()), - "code": content_blocks[-1]["content"], - }, + code = content_blocks[-1]["content"] + + if ( + request.app.state.config.CODE_INTERPRETER_ENGINE + == "pyodide" + ): + output = await event_caller( + { + "type": "execute:python", + "data": { + "id": str(uuid4()), + "code": code, + }, + } + ) + elif ( + request.app.state.config.CODE_INTERPRETER_ENGINE + == "jupyter" + ): + output = await execute_code_jupyter( + request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, + code, + ( + request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN + if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH + == "token" + else None + ), + ( + request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD + if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH + == "password" + else None + ), + ) + else: + output = { + "stdout": "Code interpreter engine not configured." 
} - ) if isinstance(output, dict): stdout = output.get("stdout", "") @@ -1687,6 +1731,38 @@ async def process_chat_response( ) output["stdout"] = "\n".join(stdoutLines) + + result = output.get("result", "") + + if result: + resultLines = result.split("\n") + for idx, line in enumerate(resultLines): + if "data:image/png;base64" in line: + id = str(uuid4()) + + # ensure the path exists + os.makedirs( + os.path.join(CACHE_DIR, "images"), + exist_ok=True, + ) + + image_path = os.path.join( + CACHE_DIR, + f"images/{id}.png", + ) + + with open(image_path, "wb") as f: + f.write( + base64.b64decode( + line.split(",")[1] + ) + ) + + resultLines[idx] = ( + f"![Output Image {idx}](/cache/images/{id}.png)" + ) + + output["result"] = "\n".join(resultLines) except Exception as e: output = str(e) diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index eb90ea5ea..99e6d9c39 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -245,7 +245,7 @@ def get_gravatar_url(email): def calculate_sha256(file_path, chunk_size): - #Compute SHA-256 hash of a file efficiently in chunks + # Compute SHA-256 hash of a file efficiently in chunks sha256 = hashlib.sha256() with open(file_path, "rb") as f: while chunk := f.read(chunk_size): diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 5182a1b17..463f67adc 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -142,13 +142,17 @@ class OAuthManager: log.debug(f"Oauth Groups claim: {oauth_claim}") log.debug(f"User oauth groups: {user_oauth_groups}") log.debug(f"User's current groups: {[g.name for g in user_current_groups]}") - log.debug(f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}") + log.debug( + f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}" + ) # Remove groups that user is no longer a part of for group_model in user_current_groups: if group_model.name not in user_oauth_groups: # Remove group from user - log.debug(f"Removing user from group {group_model.name} as it is no longer in their oauth groups") + log.debug( + f"Removing user from group {group_model.name} as it is no longer in their oauth groups" + ) user_ids = group_model.user_ids user_ids = [i for i in user_ids if i != user.id] @@ -174,7 +178,9 @@ class OAuthManager: gm.name == group_model.name for gm in user_current_groups ): # Add user to group - log.debug(f"Adding user to group {group_model.name} as it was found in their oauth groups") + log.debug( + f"Adding user to group {group_model.name} as it was found in their oauth groups" + ) user_ids = group_model.user_ids user_ids.append(user.id) @@ -289,7 +295,9 @@ class OAuthManager: base64_encoded_picture = base64.b64encode( picture ).decode("utf-8") - guessed_mime_type = mimetypes.guess_type(picture_url)[0] + guessed_mime_type = mimetypes.guess_type( + picture_url + )[0] if guessed_mime_type is None: # assume JPG, browsers are tolerant enough of image formats guessed_mime_type = "image/jpeg" @@ -307,7 +315,8 @@ class OAuthManager: username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM name = user_data.get(username_claim) - if not isinstance(name, str): + if not name: + log.warning("Username claim is missing, using email as name") name = email role = self.get_user_role(None, user_data) diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index b68b313de..5eb040434 100644 --- a/backend/open_webui/utils/payload.py +++ 
b/backend/open_webui/utils/payload.py @@ -14,6 +14,12 @@ def apply_model_system_prompt_to_body( if not system: return form_data + # Metadata (WebUI Usage) + if metadata: + variables = metadata.get("variables", {}) + if variables: + system = prompt_variables_template(system, variables) + # Legacy (API Usage) if user: template_params = { @@ -25,12 +31,6 @@ def apply_model_system_prompt_to_body( system = prompt_template(system, **template_params) - # Metadata (WebUI Usage) - if metadata: - variables = metadata.get("variables", {}) - if variables: - system = prompt_variables_template(system, variables) - form_data["messages"] = add_or_update_system_message( system, form_data.get("messages", []) ) diff --git a/backend/open_webui/utils/pdf_generator.py b/backend/open_webui/utils/pdf_generator.py index 1bb9f76b3..8b04dd81b 100644 --- a/backend/open_webui/utils/pdf_generator.py +++ b/backend/open_webui/utils/pdf_generator.py @@ -2,6 +2,7 @@ from datetime import datetime from io import BytesIO from pathlib import Path from typing import Dict, Any, List +from html import escape from markdown import markdown @@ -41,13 +42,13 @@ class PDFGenerator: def _build_html_message(self, message: Dict[str, Any]) -> str: """Build HTML for a single message.""" - role = message.get("role", "user") - content = message.get("content", "") + role = escape(message.get("role", "user")) + content = escape(message.get("content", "")) timestamp = message.get("timestamp") - model = message.get("model") if role == "assistant" else "" + model = escape(message.get("model") if role == "assistant" else "") - date_str = self.format_timestamp(timestamp) if timestamp else "" + date_str = escape(self.format_timestamp(timestamp) if timestamp else "") # extends pymdownx extension to convert markdown to html. # - https://facelessuser.github.io/pymdown-extensions/usage_notes/ @@ -76,6 +77,7 @@ class PDFGenerator: def _generate_html_body(self) -> str: """Generate the full HTML body for the PDF.""" + escaped_title = escape(self.form_data.title) return f""" @@ -84,7 +86,7 @@ class PDFGenerator:
             <body>
             <div class="container"> 
                 <div class="text-2xl font-bold">
-                    {self.form_data.title}
+                    {escaped_title}
                 </div>
                 <div>
                     {self.messages_html}
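
Not part of the patch itself, but a minimal standalone sketch of what the escape() calls introduced above guard against; the title string below is a made-up example and the surrounding f-string mirrors the template in pdf_generator.py only loosely.

    from html import escape

    # Hypothetical, user-controlled chat title. Before this patch it was
    # interpolated into the PDF's HTML template verbatim; escape() renders
    # it as literal text instead of live markup.
    title = '<img src=x onerror=alert(1)> Quarterly report'

    unsafe_html = f'<div class="text-2xl font-bold">{title}</div>'
    safe_html = f'<div class="text-2xl font-bold">{escape(title)}</div>'

    print(unsafe_html)  # the tag survives and would be rendered as markup
    print(safe_html)    # &lt;img src=x onerror=alert(1)&gt; shown as plain text
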
diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py
index b16805bf3..4917d3852 100644
--- a/backend/open_webui/utils/response.py
+++ b/backend/open_webui/utils/response.py
@@ -73,7 +73,9 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
                     "type": "function",
                     "function": {
                         "name": tool_call.get("function", {}).get("name", ""),
-                        "arguments": f"{tool_call.get('function', {}).get('arguments', {})}",
+                        "arguments": json.dumps(
+                            tool_call.get("function", {}).get("arguments", {})
+                        ),
                     },
                 }
                 openai_tool_calls.append(openai_tool_call)
diff --git a/backend/requirements.txt b/backend/requirements.txt
index c150e468e..86755f50f 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -91,7 +91,7 @@ pytube==15.0.0
 
 extract_msg
 pydub
-duckduckgo-search~=7.3.0
+duckduckgo-search~=7.3.2
 
 ## Google Drive
 google-api-python-client
diff --git a/pyproject.toml b/pyproject.toml
index 39f0bf004..544105e11 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,7 +98,7 @@ dependencies = [
 
     "extract_msg",
     "pydub",
-    "duckduckgo-search~=7.3.0",
+    "duckduckgo-search~=7.3.2",
 
     "google-api-python-client",
     "google-auth-httplib2",
diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts
index e9faf346b..d7f02564c 100644
--- a/src/lib/apis/configs/index.ts
+++ b/src/lib/apis/configs/index.ts
@@ -58,6 +58,120 @@ export const exportConfig = async (token: string) => {
 	return res;
 };
 
+export const getDirectConnectionsConfig = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/direct_connections`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const setDirectConnectionsConfig = async (token: string, config: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/direct_connections`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...config
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getCodeInterpreterConfig = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/code_interpreter`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const setCodeInterpreterConfig = async (token: string, config: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/code_interpreter`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...config
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getModelsConfig = async (token: string) => {
 	let error = null;
 
diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts
index c7fd78819..53c577a45 100644
--- a/src/lib/apis/index.ts
+++ b/src/lib/apis/index.ts
@@ -1,6 +1,11 @@
 import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
+import { getOpenAIModelsDirect } from './openai';
 
-export const getModels = async (token: string = '', base: boolean = false) => {
+export const getModels = async (
+	token: string = '',
+	connections: object | null = null,
+	base: boolean = false
+) => {
 	let error = null;
 	const res = await fetch(`${WEBUI_BASE_URL}/api/models${base ? '/base' : ''}`, {
 		method: 'GET',
@@ -25,6 +30,97 @@ export const getModels = async (token: string = '', base: boolean = false) => {
 	}
 
 	let models = res?.data ?? [];
+
+	if (connections && !base) {
+		let localModels = [];
+
+		if (connections) {
+			const OPENAI_API_BASE_URLS = connections.OPENAI_API_BASE_URLS;
+			const OPENAI_API_KEYS = connections.OPENAI_API_KEYS;
+			const OPENAI_API_CONFIGS = connections.OPENAI_API_CONFIGS;
+
+			const requests = [];
+			for (const idx in OPENAI_API_BASE_URLS) {
+				const url = OPENAI_API_BASE_URLS[idx];
+
+				if (idx.toString() in OPENAI_API_CONFIGS) {
+					const apiConfig = OPENAI_API_CONFIGS[idx.toString()] ?? {};
+
+					const enable = apiConfig?.enable ?? true;
+					const modelIds = apiConfig?.model_ids ?? [];
+
+					if (enable) {
+						if (modelIds.length > 0) {
+							const modelList = {
+								object: 'list',
+								data: modelIds.map((modelId) => ({
+									id: modelId,
+									name: modelId,
+									owned_by: 'openai',
+									openai: { id: modelId },
+									urlIdx: idx
+								}))
+							};
+
+							requests.push(
+								(async () => {
+									return modelList;
+								})()
+							);
+						} else {
+							requests.push(getOpenAIModelsDirect(url, OPENAI_API_KEYS[idx]));
+						}
+					} else {
+						requests.push(
+							(async () => {
+								return {
+									object: 'list',
+									data: [],
+									urlIdx: idx
+								};
+							})()
+						);
+					}
+				}
+			}
+
+			const responses = await Promise.all(requests);
+
+			for (const idx in responses) {
+				const response = responses[idx];
+				const apiConfig = OPENAI_API_CONFIGS[idx.toString()] ?? {};
+
+				let models = Array.isArray(response) ? response : (response?.data ?? []);
+				models = models.map((model) => ({ ...model, openai: { id: model.id }, urlIdx: idx }));
+
+				const prefixId = apiConfig.prefix_id;
+				if (prefixId) {
+					for (const model of models) {
+						model.id = `${prefixId}.${model.id}`;
+					}
+				}
+
+				localModels = localModels.concat(models);
+			}
+		}
+
+		models = models.concat(
+			localModels.map((model) => ({
+				...model,
+				name: model?.name ?? model?.id,
+				direct: true
+			}))
+		);
+
+		// Remove duplicates
+		const modelsMap = {};
+		for (const model of models) {
+			modelsMap[model.id] = model;
+		}
+
+		models = Object.values(modelsMap);
+	}
+
 	return models;
 };
diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts
index a801bcdbb..bab2d6e36 100644
--- a/src/lib/apis/openai/index.ts
+++ b/src/lib/apis/openai/index.ts
@@ -208,6 +208,33 @@ export const updateOpenAIKeys = async (token: string = '', keys: string[]) => {
 	return res.OPENAI_API_KEYS;
 };
 
+export const getOpenAIModelsDirect = async (url: string, key: string) => {
+	let error = null;
+
+	const res = await fetch(`${url}/models`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(key && { authorization: `Bearer ${key}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`;
+			return [];
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getOpenAIModels = async (token: string, urlIdx?: number) => {
 	let error = null;
 
@@ -241,33 +268,62 @@
 export const verifyOpenAIConnection = async (
 	token: string = '',
 	url: string = 'https://api.openai.com/v1',
-	key: string = ''
+	key: string = '',
+	direct: boolean = false
 ) => {
+	if (!url) {
+		throw 'OpenAI: URL is required';
+	}
+
 	let error = null;
+	let res = null;
 
-	const res = await fetch(`${OPENAI_API_BASE_URL}/verify`, {
-		method: 'POST',
-		headers: {
-			Accept: 'application/json',
-			Authorization: `Bearer ${token}`,
-			'Content-Type': 'application/json'
-		},
-		body: JSON.stringify({
-			url,
-			key
+	if (direct) {
+		res = await fetch(`${url}/models`, {
+			method: 'GET',
+			headers: {
+				Accept: 'application/json',
+				Authorization: `Bearer ${key}`,
+				'Content-Type': 'application/json'
+			}
 		})
-	})
-		.then(async (res) => {
-			if (!res.ok) throw await res.json();
-			return res.json();
-		})
-		.catch((err) => {
-			error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`;
-			return [];
-		});
+			.then(async (res) => {
+				if (!res.ok) throw await res.json();
+				return res.json();
+			})
+			.catch((err) => {
+				error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`;
+				return [];
+			});
 
-	if (error) {
-		throw error;
+		if (error) {
+			throw error;
+		}
+	} else {
+		res = await fetch(`${OPENAI_API_BASE_URL}/verify`, {
+			method: 'POST',
+			headers: {
+				Accept: 'application/json',
+				Authorization: `Bearer ${token}`,
+				'Content-Type': 'application/json'
+			},
+			body: JSON.stringify({
+				url,
+				key
+			})
+		})
+			.then(async (res) => {
+				if (!res.ok) throw await res.json();
+				return res.json();
+			})
+			.catch((err) => {
+				error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`;
+				return [];
+			});
+
+		if (error) {
+			throw error;
+		}
 	}
 
 	return res;
diff --git a/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte b/src/lib/components/AddConnectionModal.svelte
similarity index 98%
rename from src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte
rename to src/lib/components/AddConnectionModal.svelte
index a8726a546..95074c258 100644
--- a/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte
+++ b/src/lib/components/AddConnectionModal.svelte
@@ -20,7 +20,9 @@
 	export let show = false;
 	export let edit = false;
+
 	export let ollama = false;
+	export let direct = false;
 
 	export let connection = null;
 
@@ -46,9 +48,11 @@
 	};
 
 	const verifyOpenAIHandler = async () => {
-		const res = await verifyOpenAIConnection(localStorage.token, url, key).catch((error) => {
-			toast.error(`${error}`);
-		});
+		const res = await verifyOpenAIConnection(localStorage.token, url, key, direct).catch(
+			(error) => {
+				toast.error(`${error}`);
+			}
+		);
 
 		if (res) {
 			toast.success($i18n.t('Server connection verified'));
diff --git a/src/lib/components/admin/Evaluations/Feedbacks.svelte b/src/lib/components/admin/Evaluations/Feedbacks.svelte
index e43081302..e73adb027 100644
--- a/src/lib/components/admin/Evaluations/Feedbacks.svelte
+++ b/src/lib/components/admin/Evaluations/Feedbacks.svelte
@@ -65,7 +65,7 @@
 	};
 
 	const shareHandler = async () => {
-		toast.success($i18n.t('Redirecting you to OpenWebUI Community'));
+		toast.success($i18n.t('Redirecting you to Open WebUI Community'));
 
 		// remove snapshot from feedbacks
 		const feedbacksToShare = feedbacks.map((f) => {
@@ -266,7 +266,7 @@
 				}}
 			>
-				{$i18n.t('Share to OpenWebUI Community')}
+				{$i18n.t('Share to Open WebUI Community')}
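For reviewers trying out the `/configs/direct_connections` and `/configs/code_interpreter` client helpers added in `src/lib/apis/configs/index.ts` above, here is a minimal usage sketch. It is not part of the patch, and the payload field names (`ENABLE_DIRECT_CONNECTIONS`, `ENABLE_CODE_INTERPRETER`, `CODE_INTERPRETER_ENGINE`) are assumed to mirror the backend `PersistentConfig` keys rather than taken from this diff.

```ts
// Sketch only: loads and updates the new admin configs.
// Field names below are assumptions, not confirmed by the diff.
import {
	getDirectConnectionsConfig,
	setDirectConnectionsConfig,
	getCodeInterpreterConfig,
	setCodeInterpreterConfig
} from '$lib/apis/configs';

const refreshAdminConfigs = async (token: string) => {
	// Each helper throws `err.detail` when the backend responds with a non-OK status.
	const directConnections = await getDirectConnectionsConfig(token).catch(() => null);
	const codeInterpreter = await getCodeInterpreterConfig(token).catch(() => null);

	if (directConnections) {
		// Assumed field name, mirroring ENABLE_DIRECT_CONNECTIONS on the backend.
		await setDirectConnectionsConfig(token, {
			...directConnections,
			ENABLE_DIRECT_CONNECTIONS: true
		});
	}

	if (codeInterpreter) {
		// Assumed field names, mirroring the code_interpreter PersistentConfig entries.
		await setCodeInterpreterConfig(token, {
			...codeInterpreter,
			ENABLE_CODE_INTERPRETER: true,
			CODE_INTERPRETER_ENGINE: 'pyodide'
		});
	}
};

refreshAdminConfigs(localStorage.token);
```

Both setters POST the object back as-is, so spreading the fetched config and overriding individual keys keeps unrelated settings intact.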
diff --git a/src/lib/components/admin/Functions.svelte b/src/lib/components/admin/Functions.svelte
index 80c7d11cd..9134ace55 100644
--- a/src/lib/components/admin/Functions.svelte
+++ b/src/lib/components/admin/Functions.svelte
@@ -3,7 +3,7 @@
 	import fileSaver from 'file-saver';
 	const { saveAs } = fileSaver;
 
-	import { WEBUI_NAME, config, functions, models } from '$lib/stores';
+	import { WEBUI_NAME, config, functions, models, settings } from '$lib/stores';
 	import { onMount, getContext, tick } from 'svelte';
 	import { goto } from '$app/navigation';
 
@@ -65,7 +65,7 @@
 			return null;
 		});
 
-		toast.success($i18n.t('Redirecting you to OpenWebUI Community'));
+		toast.success($i18n.t('Redirecting you to Open WebUI Community'));
 
 		const url = 'https://openwebui.com';
 
@@ -126,7 +126,12 @@
 			toast.success($i18n.t('Function deleted successfully'));
 
 			functions.set(await getFunctions(localStorage.token));
-			models.set(await getModels(localStorage.token));
+			models.set(
+				await getModels(
+					localStorage.token,
+					$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null)
+				)
+			);
 		}
 	};
 
@@ -147,7 +152,12 @@
 		}
 
 		functions.set(await getFunctions(localStorage.token));
-		models.set(await getModels(localStorage.token));
+		models.set(
+			await getModels(
+				localStorage.token,
+				$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null)
+			)
+		);
 	}
 };
 
@@ -359,7 +369,13 @@
 					bind:state={func.is_active}
 					on:change={async (e) => {
 						toggleFunctionById(localStorage.token, func.id);
-						models.set(await getModels(localStorage.token));
+						models.set(
+							await getModels(
+								localStorage.token,
+								$config?.features?.enable_direct_connections &&
+									($settings?.directConnections ?? null)
+							)
+						);
 					}}
 				/>
 
@@ -453,7 +469,7 @@
 			{#if $config?.features.enable_community_sharing}
-					{$i18n.t('Made by OpenWebUI Community')}
+					{$i18n.t('Made by Open WebUI Community')}
 					{
 						await tick();
-						models.set(await getModels(localStorage.token));
+						models.set(
+							await getModels(
+								localStorage.token,
+								$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null)
+							)
+						);
 					}}
 				/>
 
@@ -517,7 +538,12 @@
 			toast.success($i18n.t('Functions imported successfully'));
 
 			functions.set(await getFunctions(localStorage.token));
-			models.set(await getModels(localStorage.token));
+			models.set(
+				await getModels(
+					localStorage.token,
+					$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null)
+				)
+			);
 		};
 
 		reader.readAsText(importFiles[0]);
diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte
index f0886ea5c..415e4377a 100644
--- a/src/lib/components/admin/Settings.svelte
+++ b/src/lib/components/admin/Settings.svelte
@@ -19,6 +19,7 @@
 	import ChartBar from '../icons/ChartBar.svelte';
 	import DocumentChartBar from '../icons/DocumentChartBar.svelte';
 	import Evaluations from './Settings/Evaluations.svelte';
+	import CodeInterpreter from './Settings/CodeInterpreter.svelte';
 
 	const i18n = getContext('i18n');
 
@@ -188,6 +189,32 @@
 					{$i18n.t('Web Search')}
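For the extended `getModels()` and the new `direct` mode of `verifyOpenAIConnection()` introduced in `src/lib/apis/index.ts` and `src/lib/apis/openai/index.ts`, a hedged usage sketch follows. The `connections` object keys match what the new code reads (`OPENAI_API_BASE_URLS`, `OPENAI_API_KEYS`, `OPENAI_API_CONFIGS`); the URL, key, and prefix values are placeholders.

```ts
// Sketch only: fetches models through a user-defined direct connection.
// URL, key, and prefix values are placeholders; the object keys are the ones
// the extended getModels() implementation actually reads.
import { getModels } from '$lib/apis';
import { verifyOpenAIConnection } from '$lib/apis/openai';

const directConnections = {
	OPENAI_API_BASE_URLS: ['https://api.openai.com/v1'],
	OPENAI_API_KEYS: ['sk-placeholder'],
	OPENAI_API_CONFIGS: {
		'0': { enable: true, model_ids: [], prefix_id: 'direct' }
	}
};

// direct=true skips the backend /verify route and calls `${url}/models` from the browser.
await verifyOpenAIConnection(
	localStorage.token,
	directConnections.OPENAI_API_BASE_URLS[0],
	directConnections.OPENAI_API_KEYS[0],
	true
);

// Models resolved through these URLs come back flagged with `direct: true`,
// and with ids prefixed as `${prefix_id}.${id}` when prefix_id is set.
const models = await getModels(localStorage.token, directConnections);
console.log(models.filter((m) => m.direct));
```

When a connection's `enable` flag is false, the new code contributes an empty list for that URL, and duplicate model ids are collapsed with the last occurrence winning.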