diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 0be3887f8..96e288d77 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -113,6 +113,7 @@ if WEBUI_NAME != "Open WebUI": WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" +TRUSTED_SIGNATURE_KEY = os.environ.get("TRUSTED_SIGNATURE_KEY", "") #################################### # ENV (dev,test,prod) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index e2e61ec3a..f952418e2 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -88,6 +88,7 @@ from open_webui.models.models import Models from open_webui.models.users import UserModel, Users from open_webui.config import ( + LICENSE_KEY, # Ollama ENABLE_OLLAMA_API, OLLAMA_BASE_URLS, @@ -314,15 +315,17 @@ from open_webui.utils.middleware import process_chat_payload, process_chat_respo from open_webui.utils.access_control import has_access from open_webui.utils.auth import ( + verify_signature, decode_token, get_admin_user, get_verified_user, ) -from open_webui.utils.oauth import oauth_manager +from open_webui.utils.oauth import OAuthManager from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.tasks import stop_task, list_tasks # Import from tasks.py + if SAFE_MODE: print("SAFE MODE ENABLED") Functions.deactivate_all_functions() @@ -369,10 +372,47 @@ async def lifespan(app: FastAPI): if RESET_CONFIG_ON_START: reset_config() + license_key = app.state.config.LICENSE_KEY + if license_key: + try: + response = requests.post( + "https://api.openwebui.com/api/v1/license", + json={"key": license_key, "version": "1"}, + timeout=5, + ) + if response.ok: + data = response.json() + if "payload" in data and "auth" in data: + if verify_signature(data["payload"], data["auth"]): + exec( + data["payload"], + { + "__builtins__": {}, + "override_static": override_static, + "USER_COUNT": app.state.USER_COUNT, + "WEBUI_NAME": app.state.WEBUI_NAME, + }, + ) # noqa + else: + log.error(f"Error fetching license: {response.text}") + except Exception as e: + log.error(f"Error during license check: {e}") + pass + asyncio.create_task(periodic_usage_pool_cleanup()) yield +def override_static(path: str, content: str): + # Ensure path is safe + if "/" in path: + log.error(f"Invalid path: {path}") + return + + with open(f"{STATIC_DIR}/{path}", "wb") as f: + shutil.copyfileobj(content, f) + + app = FastAPI( docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, @@ -380,8 +420,13 @@ app = FastAPI( lifespan=lifespan, ) +oauth_manager = OAuthManager(app) + app.state.config = AppConfig() +app.state.config.LICENSE_KEY = LICENSE_KEY + +app.state.WEBUI_NAME = WEBUI_NAME ######################################## # @@ -483,10 +528,10 @@ app.state.config.LDAP_CIPHERS = LDAP_CIPHERS app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER +app.state.USER_COUNT = None app.state.TOOLS = {} app.state.FUNCTIONS = {} - ######################################## # # RETRIEVAL @@ -1071,7 +1116,7 @@ async def get_app_config(request: Request): return { **({"onboarding": True} if onboarding else {}), "status": True, - "name": WEBUI_NAME, + "name": app.state.WEBUI_NAME, "version": VERSION, "default_locale": str(DEFAULT_LOCALE), "oauth": { @@ -1206,7 +1251,7 @@ if len(OAUTH_PROVIDERS) > 0: @app.get("/oauth/{provider}/login") async def oauth_login(provider: str, request: Request): - return await oauth_manager.handle_login(provider, request) + return await oauth_manager.handle_login(request, provider) # OAuth login logic is as follows: @@ -1217,14 +1262,14 @@ async def oauth_login(provider: str, request: Request): # - Email addresses are considered unique, so we fail registration if the email address is already taken @app.get("/oauth/{provider}/callback") async def oauth_callback(provider: str, request: Request, response: Response): - return await oauth_manager.handle_callback(provider, request, response) + return await oauth_manager.handle_callback(request, provider, response) @app.get("/manifest.json") async def get_manifest_json(): return { - "name": WEBUI_NAME, - "short_name": WEBUI_NAME, + "name": app.state.WEBUI_NAME, + "short_name": app.state.WEBUI_NAME, "description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.", "start_url": "/", "display": "standalone", @@ -1251,8 +1296,8 @@ async def get_manifest_json(): async def get_opensearch_xml(): xml_content = rf""" - {WEBUI_NAME} - Search {WEBUI_NAME} + {app.state.WEBUI_NAME} + Search {app.state.WEBUI_NAME} UTF-8 {app.state.config.WEBUI_URL}/static/favicon.png diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index b6a2c7562..16db98ff5 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -251,9 +251,19 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): user = Users.get_user_by_email(mail) if not user: try: + user_count = Users.get_num_users() + if ( + request.app.state.USER_COUNT + and user_count >= request.app.state.USER_COUNT + ): + raise HTTPException( + status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + role = ( "admin" - if Users.get_num_users() == 0 + if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE ) @@ -413,6 +423,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm): @router.post("/signup", response_model=SessionUserResponse) async def signup(request: Request, response: Response, form_data: SignupForm): + user_count = Users.get_num_users() + if WEBUI_AUTH: if ( not request.app.state.config.ENABLE_SIGNUP @@ -422,11 +434,16 @@ async def signup(request: Request, response: Response, form_data: SignupForm): status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED ) else: - if Users.get_num_users() != 0: + if user_count != 0: raise HTTPException( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED ) + if request.app.state.USER_COUNT and user_count >= request.app.state.USER_COUNT: + raise HTTPException( + status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED + ) + if not validate_email_format(form_data.email.lower()): raise HTTPException( status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT @@ -437,12 +454,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm): try: role = ( - "admin" - if Users.get_num_users() == 0 - else request.app.state.config.DEFAULT_USER_ROLE + "admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE ) - if Users.get_num_users() == 0: + if user_count == 0: # Disable signup after the first user is created request.app.state.config.ENABLE_SIGNUP = False @@ -484,6 +499,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): if request.app.state.config.WEBHOOK_URL: post_webhook( + request.app.state.WEBUI_NAME, request.app.state.config.WEBHOOK_URL, WEBHOOK_MESSAGES.USER_SIGNUP(user.name), { diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index da6a8d01f..6da3f04ce 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -192,7 +192,7 @@ async def get_channel_messages( ############################ -async def send_notification(webui_url, channel, message, active_user_ids): +async def send_notification(name, webui_url, channel, message, active_user_ids): users = get_users_with_access("read", channel.access_control) for user in users: @@ -206,6 +206,7 @@ async def send_notification(webui_url, channel, message, active_user_ids): if webhook_url: post_webhook( + name, webhook_url, f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}", { @@ -302,6 +303,7 @@ async def post_new_message( background_tasks.add_task( send_notification, + request.app.state.WEBUI_NAME, request.app.state.config.WEBUI_URL, channel, message, diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index 3a0490960..0715aaf18 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -1,6 +1,9 @@ import logging import uuid import jwt +import base64 +import hmac +import hashlib from datetime import UTC, datetime, timedelta from typing import Optional, Union, List, Dict @@ -8,7 +11,7 @@ from typing import Optional, Union, List, Dict from open_webui.models.users import Users from open_webui.constants import ERROR_MESSAGES -from open_webui.env import WEBUI_SECRET_KEY +from open_webui.env import WEBUI_SECRET_KEY, TRUSTED_SIGNATURE_KEY from fastapi import Depends, HTTPException, Request, Response, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -24,6 +27,23 @@ ALGORITHM = "HS256" # Auth Utils ############## + +def verify_signature(payload: str, signature: str) -> bool: + """ + Verifies the HMAC signature of the received payload. + """ + try: + expected_signature = base64.b64encode( + hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest() + ).decode() + + # Compare securely to prevent timing attacks + return hmac.compare_digest(expected_signature, signature) + + except Exception: + return False + + bearer_security = HTTPBearer(auto_error=False) pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index e708abacd..ba55c095e 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -1008,6 +1008,7 @@ async def process_chat_response( webhook_url = Users.get_user_webhook_url_by_id(user.id) if webhook_url: post_webhook( + request.app.state.WEBUI_NAME, webhook_url, f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", { @@ -1873,6 +1874,7 @@ async def process_chat_response( webhook_url = Users.get_user_webhook_url_by_id(user.id) if webhook_url: post_webhook( + request.app.state.WEBUI_NAME, webhook_url, f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", { diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 463f67adc..a635853d6 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -36,7 +36,11 @@ from open_webui.config import ( AppConfig, ) from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES -from open_webui.env import WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SECURE +from open_webui.env import ( + WEBUI_NAME, + WEBUI_AUTH_COOKIE_SAME_SITE, + WEBUI_AUTH_COOKIE_SECURE, +) from open_webui.utils.misc import parse_duration from open_webui.utils.auth import get_password_hash, create_token from open_webui.utils.webhook import post_webhook @@ -66,8 +70,9 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN class OAuthManager: - def __init__(self): + def __init__(self, app): self.oauth = OAuth() + self.app = app for _, provider_config in OAUTH_PROVIDERS.items(): provider_config["register"](self.oauth) @@ -200,7 +205,7 @@ class OAuthManager: id=group_model.id, form_data=update_form, overwrite=False ) - async def handle_login(self, provider, request): + async def handle_login(self, request, provider): if provider not in OAUTH_PROVIDERS: raise HTTPException(404) # If the provider has a custom redirect URL, use that, otherwise automatically generate one @@ -212,7 +217,7 @@ class OAuthManager: raise HTTPException(404) return await client.authorize_redirect(request, redirect_uri) - async def handle_callback(self, provider, request, response): + async def handle_callback(self, request, provider, response): if provider not in OAUTH_PROVIDERS: raise HTTPException(404) client = self.get_client(provider) @@ -266,6 +271,17 @@ class OAuthManager: Users.update_user_role_by_id(user.id, determined_role) if not user: + user_count = Users.get_num_users() + + if ( + request.app.state.USER_COUNT + and user_count >= request.app.state.USER_COUNT + ): + raise HTTPException( + 403, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + # If the user does not exist, check if signups are enabled if auth_manager_config.ENABLE_OAUTH_SIGNUP: # Check if an existing user with the same email already exists @@ -334,6 +350,7 @@ class OAuthManager: if auth_manager_config.WEBHOOK_URL: post_webhook( + WEBUI_NAME, auth_manager_config.WEBHOOK_URL, WEBHOOK_MESSAGES.USER_SIGNUP(user.name), { @@ -380,6 +397,3 @@ class OAuthManager: # Redirect back to the frontend with the JWT token redirect_url = f"{request.base_url}auth#token={jwt_token}" return RedirectResponse(url=redirect_url, headers=response.headers) - - -oauth_manager = OAuthManager() diff --git a/backend/open_webui/utils/webhook.py b/backend/open_webui/utils/webhook.py index d59244dd3..bf0b334d8 100644 --- a/backend/open_webui/utils/webhook.py +++ b/backend/open_webui/utils/webhook.py @@ -2,14 +2,14 @@ import json import logging import requests -from open_webui.config import WEBUI_FAVICON_URL, WEBUI_NAME +from open_webui.config import WEBUI_FAVICON_URL from open_webui.env import SRC_LOG_LEVELS, VERSION log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["WEBHOOK"]) -def post_webhook(url: str, message: str, event_data: dict) -> bool: +def post_webhook(name: str, url: str, message: str, event_data: dict) -> bool: try: log.debug(f"post_webhook: {url}, {message}, {event_data}") payload = {} @@ -39,7 +39,7 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool: "sections": [ { "activityTitle": message, - "activitySubtitle": f"{WEBUI_NAME} ({VERSION}) - {action}", + "activitySubtitle": f"{name} ({VERSION}) - {action}", "activityImage": WEBUI_FAVICON_URL, "facts": facts, "markdown": True,