This commit is contained in:
Timothy Jaeryang Baek 2025-02-16 00:11:18 -08:00
parent 91de8e082e
commit 63cf80a456
8 changed files with 127 additions and 27 deletions

View File

@ -113,6 +113,7 @@ if WEBUI_NAME != "Open WebUI":
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
TRUSTED_SIGNATURE_KEY = os.environ.get("TRUSTED_SIGNATURE_KEY", "")
#################################### ####################################
# ENV (dev,test,prod) # ENV (dev,test,prod)

View File

@ -88,6 +88,7 @@ from open_webui.models.models import Models
from open_webui.models.users import UserModel, Users from open_webui.models.users import UserModel, Users
from open_webui.config import ( from open_webui.config import (
LICENSE_KEY,
# Ollama # Ollama
ENABLE_OLLAMA_API, ENABLE_OLLAMA_API,
OLLAMA_BASE_URLS, OLLAMA_BASE_URLS,
@ -314,15 +315,17 @@ from open_webui.utils.middleware import process_chat_payload, process_chat_respo
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
from open_webui.utils.auth import ( from open_webui.utils.auth import (
verify_signature,
decode_token, decode_token,
get_admin_user, get_admin_user,
get_verified_user, get_verified_user,
) )
from open_webui.utils.oauth import oauth_manager from open_webui.utils.oauth import OAuthManager
from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py from open_webui.tasks import stop_task, list_tasks # Import from tasks.py
if SAFE_MODE: if SAFE_MODE:
print("SAFE MODE ENABLED") print("SAFE MODE ENABLED")
Functions.deactivate_all_functions() Functions.deactivate_all_functions()
@ -369,10 +372,47 @@ async def lifespan(app: FastAPI):
if RESET_CONFIG_ON_START: if RESET_CONFIG_ON_START:
reset_config() reset_config()
license_key = app.state.config.LICENSE_KEY
if license_key:
try:
response = requests.post(
"https://api.openwebui.com/api/v1/license",
json={"key": license_key, "version": "1"},
timeout=5,
)
if response.ok:
data = response.json()
if "payload" in data and "auth" in data:
if verify_signature(data["payload"], data["auth"]):
exec(
data["payload"],
{
"__builtins__": {},
"override_static": override_static,
"USER_COUNT": app.state.USER_COUNT,
"WEBUI_NAME": app.state.WEBUI_NAME,
},
) # noqa
else:
log.error(f"Error fetching license: {response.text}")
except Exception as e:
log.error(f"Error during license check: {e}")
pass
asyncio.create_task(periodic_usage_pool_cleanup()) asyncio.create_task(periodic_usage_pool_cleanup())
yield yield
def override_static(path: str, content: str):
# Ensure path is safe
if "/" in path:
log.error(f"Invalid path: {path}")
return
with open(f"{STATIC_DIR}/{path}", "wb") as f:
shutil.copyfileobj(content, f)
app = FastAPI( app = FastAPI(
docs_url="/docs" if ENV == "dev" else None, docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None,
@ -380,8 +420,13 @@ app = FastAPI(
lifespan=lifespan, lifespan=lifespan,
) )
oauth_manager = OAuthManager(app)
app.state.config = AppConfig() app.state.config = AppConfig()
app.state.config.LICENSE_KEY = LICENSE_KEY
app.state.WEBUI_NAME = WEBUI_NAME
######################################## ########################################
# #
@ -483,10 +528,10 @@ app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.USER_COUNT = None
app.state.TOOLS = {} app.state.TOOLS = {}
app.state.FUNCTIONS = {} app.state.FUNCTIONS = {}
######################################## ########################################
# #
# RETRIEVAL # RETRIEVAL
@ -1071,7 +1116,7 @@ async def get_app_config(request: Request):
return { return {
**({"onboarding": True} if onboarding else {}), **({"onboarding": True} if onboarding else {}),
"status": True, "status": True,
"name": WEBUI_NAME, "name": app.state.WEBUI_NAME,
"version": VERSION, "version": VERSION,
"default_locale": str(DEFAULT_LOCALE), "default_locale": str(DEFAULT_LOCALE),
"oauth": { "oauth": {
@ -1206,7 +1251,7 @@ if len(OAUTH_PROVIDERS) > 0:
@app.get("/oauth/{provider}/login") @app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request): async def oauth_login(provider: str, request: Request):
return await oauth_manager.handle_login(provider, request) return await oauth_manager.handle_login(request, provider)
# OAuth login logic is as follows: # OAuth login logic is as follows:
@ -1217,14 +1262,14 @@ async def oauth_login(provider: str, request: Request):
# - Email addresses are considered unique, so we fail registration if the email address is already taken # - Email addresses are considered unique, so we fail registration if the email address is already taken
@app.get("/oauth/{provider}/callback") @app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response): async def oauth_callback(provider: str, request: Request, response: Response):
return await oauth_manager.handle_callback(provider, request, response) return await oauth_manager.handle_callback(request, provider, response)
@app.get("/manifest.json") @app.get("/manifest.json")
async def get_manifest_json(): async def get_manifest_json():
return { return {
"name": WEBUI_NAME, "name": app.state.WEBUI_NAME,
"short_name": WEBUI_NAME, "short_name": app.state.WEBUI_NAME,
"description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.", "description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
"start_url": "/", "start_url": "/",
"display": "standalone", "display": "standalone",
@ -1251,8 +1296,8 @@ async def get_manifest_json():
async def get_opensearch_xml(): async def get_opensearch_xml():
xml_content = rf""" xml_content = rf"""
<OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/"> <OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
<ShortName>{WEBUI_NAME}</ShortName> <ShortName>{app.state.WEBUI_NAME}</ShortName>
<Description>Search {WEBUI_NAME}</Description> <Description>Search {app.state.WEBUI_NAME}</Description>
<InputEncoding>UTF-8</InputEncoding> <InputEncoding>UTF-8</InputEncoding>
<Image width="16" height="16" type="image/x-icon">{app.state.config.WEBUI_URL}/static/favicon.png</Image> <Image width="16" height="16" type="image/x-icon">{app.state.config.WEBUI_URL}/static/favicon.png</Image>
<Url type="text/html" method="get" template="{app.state.config.WEBUI_URL}/?q={"{searchTerms}"}"/> <Url type="text/html" method="get" template="{app.state.config.WEBUI_URL}/?q={"{searchTerms}"}"/>

View File

@ -251,9 +251,19 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
user = Users.get_user_by_email(mail) user = Users.get_user_by_email(mail)
if not user: if not user:
try: try:
user_count = Users.get_num_users()
if (
request.app.state.USER_COUNT
and user_count >= request.app.state.USER_COUNT
):
raise HTTPException(
status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
role = ( role = (
"admin" "admin"
if Users.get_num_users() == 0 if user_count == 0
else request.app.state.config.DEFAULT_USER_ROLE else request.app.state.config.DEFAULT_USER_ROLE
) )
@ -413,6 +423,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
@router.post("/signup", response_model=SessionUserResponse) @router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm): async def signup(request: Request, response: Response, form_data: SignupForm):
user_count = Users.get_num_users()
if WEBUI_AUTH: if WEBUI_AUTH:
if ( if (
not request.app.state.config.ENABLE_SIGNUP not request.app.state.config.ENABLE_SIGNUP
@ -422,7 +434,12 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
) )
else: else:
if Users.get_num_users() != 0: if user_count != 0:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
if request.app.state.USER_COUNT and user_count >= request.app.state.USER_COUNT:
raise HTTPException( raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
) )
@ -437,12 +454,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
try: try:
role = ( role = (
"admin" "admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
if Users.get_num_users() == 0
else request.app.state.config.DEFAULT_USER_ROLE
) )
if Users.get_num_users() == 0: if user_count == 0:
# Disable signup after the first user is created # Disable signup after the first user is created
request.app.state.config.ENABLE_SIGNUP = False request.app.state.config.ENABLE_SIGNUP = False
@ -484,6 +499,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
if request.app.state.config.WEBHOOK_URL: if request.app.state.config.WEBHOOK_URL:
post_webhook( post_webhook(
request.app.state.WEBUI_NAME,
request.app.state.config.WEBHOOK_URL, request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name), WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{ {

View File

@ -192,7 +192,7 @@ async def get_channel_messages(
############################ ############################
async def send_notification(webui_url, channel, message, active_user_ids): async def send_notification(name, webui_url, channel, message, active_user_ids):
users = get_users_with_access("read", channel.access_control) users = get_users_with_access("read", channel.access_control)
for user in users: for user in users:
@ -206,6 +206,7 @@ async def send_notification(webui_url, channel, message, active_user_ids):
if webhook_url: if webhook_url:
post_webhook( post_webhook(
name,
webhook_url, webhook_url,
f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}", f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
{ {
@ -302,6 +303,7 @@ async def post_new_message(
background_tasks.add_task( background_tasks.add_task(
send_notification, send_notification,
request.app.state.WEBUI_NAME,
request.app.state.config.WEBUI_URL, request.app.state.config.WEBUI_URL,
channel, channel,
message, message,

View File

@ -1,6 +1,9 @@
import logging import logging
import uuid import uuid
import jwt import jwt
import base64
import hmac
import hashlib
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from typing import Optional, Union, List, Dict from typing import Optional, Union, List, Dict
@ -8,7 +11,7 @@ from typing import Optional, Union, List, Dict
from open_webui.models.users import Users from open_webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SECRET_KEY from open_webui.env import WEBUI_SECRET_KEY, TRUSTED_SIGNATURE_KEY
from fastapi import Depends, HTTPException, Request, Response, status from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
@ -24,6 +27,23 @@ ALGORITHM = "HS256"
# Auth Utils # Auth Utils
############## ##############
def verify_signature(payload: str, signature: str) -> bool:
"""
Verifies the HMAC signature of the received payload.
"""
try:
expected_signature = base64.b64encode(
hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest()
).decode()
# Compare securely to prevent timing attacks
return hmac.compare_digest(expected_signature, signature)
except Exception:
return False
bearer_security = HTTPBearer(auto_error=False) bearer_security = HTTPBearer(auto_error=False)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

View File

@ -1008,6 +1008,7 @@ async def process_chat_response(
webhook_url = Users.get_user_webhook_url_by_id(user.id) webhook_url = Users.get_user_webhook_url_by_id(user.id)
if webhook_url: if webhook_url:
post_webhook( post_webhook(
request.app.state.WEBUI_NAME,
webhook_url, webhook_url,
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
{ {
@ -1873,6 +1874,7 @@ async def process_chat_response(
webhook_url = Users.get_user_webhook_url_by_id(user.id) webhook_url = Users.get_user_webhook_url_by_id(user.id)
if webhook_url: if webhook_url:
post_webhook( post_webhook(
request.app.state.WEBUI_NAME,
webhook_url, webhook_url,
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
{ {

View File

@ -36,7 +36,11 @@ from open_webui.config import (
AppConfig, AppConfig,
) )
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SECURE from open_webui.env import (
WEBUI_NAME,
WEBUI_AUTH_COOKIE_SAME_SITE,
WEBUI_AUTH_COOKIE_SECURE,
)
from open_webui.utils.misc import parse_duration from open_webui.utils.misc import parse_duration
from open_webui.utils.auth import get_password_hash, create_token from open_webui.utils.auth import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook from open_webui.utils.webhook import post_webhook
@ -66,8 +70,9 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
class OAuthManager: class OAuthManager:
def __init__(self): def __init__(self, app):
self.oauth = OAuth() self.oauth = OAuth()
self.app = app
for _, provider_config in OAUTH_PROVIDERS.items(): for _, provider_config in OAUTH_PROVIDERS.items():
provider_config["register"](self.oauth) provider_config["register"](self.oauth)
@ -200,7 +205,7 @@ class OAuthManager:
id=group_model.id, form_data=update_form, overwrite=False id=group_model.id, form_data=update_form, overwrite=False
) )
async def handle_login(self, provider, request): async def handle_login(self, request, provider):
if provider not in OAUTH_PROVIDERS: if provider not in OAUTH_PROVIDERS:
raise HTTPException(404) raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one # If the provider has a custom redirect URL, use that, otherwise automatically generate one
@ -212,7 +217,7 @@ class OAuthManager:
raise HTTPException(404) raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri) return await client.authorize_redirect(request, redirect_uri)
async def handle_callback(self, provider, request, response): async def handle_callback(self, request, provider, response):
if provider not in OAUTH_PROVIDERS: if provider not in OAUTH_PROVIDERS:
raise HTTPException(404) raise HTTPException(404)
client = self.get_client(provider) client = self.get_client(provider)
@ -266,6 +271,17 @@ class OAuthManager:
Users.update_user_role_by_id(user.id, determined_role) Users.update_user_role_by_id(user.id, determined_role)
if not user: if not user:
user_count = Users.get_num_users()
if (
request.app.state.USER_COUNT
and user_count >= request.app.state.USER_COUNT
):
raise HTTPException(
403,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
# If the user does not exist, check if signups are enabled # If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP: if auth_manager_config.ENABLE_OAUTH_SIGNUP:
# Check if an existing user with the same email already exists # Check if an existing user with the same email already exists
@ -334,6 +350,7 @@ class OAuthManager:
if auth_manager_config.WEBHOOK_URL: if auth_manager_config.WEBHOOK_URL:
post_webhook( post_webhook(
WEBUI_NAME,
auth_manager_config.WEBHOOK_URL, auth_manager_config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name), WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{ {
@ -380,6 +397,3 @@ class OAuthManager:
# Redirect back to the frontend with the JWT token # Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}" redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url, headers=response.headers) return RedirectResponse(url=redirect_url, headers=response.headers)
oauth_manager = OAuthManager()

View File

@ -2,14 +2,14 @@ import json
import logging import logging
import requests import requests
from open_webui.config import WEBUI_FAVICON_URL, WEBUI_NAME from open_webui.config import WEBUI_FAVICON_URL
from open_webui.env import SRC_LOG_LEVELS, VERSION from open_webui.env import SRC_LOG_LEVELS, VERSION
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["WEBHOOK"]) log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
def post_webhook(url: str, message: str, event_data: dict) -> bool: def post_webhook(name: str, url: str, message: str, event_data: dict) -> bool:
try: try:
log.debug(f"post_webhook: {url}, {message}, {event_data}") log.debug(f"post_webhook: {url}, {message}, {event_data}")
payload = {} payload = {}
@ -39,7 +39,7 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool:
"sections": [ "sections": [
{ {
"activityTitle": message, "activityTitle": message,
"activitySubtitle": f"{WEBUI_NAME} ({VERSION}) - {action}", "activitySubtitle": f"{name} ({VERSION}) - {action}",
"activityImage": WEBUI_FAVICON_URL, "activityImage": WEBUI_FAVICON_URL,
"facts": facts, "facts": facts,
"markdown": True, "markdown": True,