diff --git a/CHANGELOG.md b/CHANGELOG.md index 98ba0c4c2..f7416361d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +### Added +- **🌐 Enhanced Translations**: Added Slovak language support and improved the Czech translation. + ## [0.4.8] - 2024-12-07 ### Added diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py deleted file mode 100644 index 5c24c2633..000000000 --- a/backend/open_webui/apps/audio/main.py +++ /dev/null @@ -1,703 +0,0 @@ -import hashlib -import json -import logging -import os -import uuid -from functools import lru_cache -from pathlib import Path -from pydub import AudioSegment -from pydub.silence import split_on_silence - -import aiohttp -import aiofiles -import requests -from open_webui.config import ( - AUDIO_STT_ENGINE, - AUDIO_STT_MODEL, - AUDIO_STT_OPENAI_API_BASE_URL, - AUDIO_STT_OPENAI_API_KEY, - AUDIO_TTS_API_KEY, - AUDIO_TTS_ENGINE, - AUDIO_TTS_MODEL, - AUDIO_TTS_OPENAI_API_BASE_URL, - AUDIO_TTS_OPENAI_API_KEY, - AUDIO_TTS_SPLIT_ON, - AUDIO_TTS_VOICE, - AUDIO_TTS_AZURE_SPEECH_REGION, - AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, - CACHE_DIR, - CORS_ALLOW_ORIGIN, - WHISPER_MODEL, - WHISPER_MODEL_AUTO_UPDATE, - WHISPER_MODEL_DIR, - AppConfig, -) - -from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ( - ENV, - SRC_LOG_LEVELS, - DEVICE_TYPE, - ENABLE_FORWARD_USER_INFO_HEADERS, -) - -from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse -from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_verified_user - -# Constants -MAX_FILE_SIZE_MB = 25 -MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes - - -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["AUDIO"]) - -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.config = AppConfig() - -app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL -app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY -app.state.config.STT_ENGINE = AUDIO_STT_ENGINE -app.state.config.STT_MODEL = AUDIO_STT_MODEL - -app.state.config.WHISPER_MODEL = WHISPER_MODEL -app.state.faster_whisper_model = None - -app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL -app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY -app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE -app.state.config.TTS_MODEL = AUDIO_TTS_MODEL -app.state.config.TTS_VOICE = AUDIO_TTS_VOICE -app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY -app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON - - -app.state.speech_synthesiser = None -app.state.speech_speaker_embeddings_dataset = None - -app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION -app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT - -# setting device type for whisper model -whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" -log.info(f"whisper_device_type: 
{whisper_device_type}") - -SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") -SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) - - -def set_faster_whisper_model(model: str, auto_update: bool = False): - if model and app.state.config.STT_ENGINE == "": - from faster_whisper import WhisperModel - - faster_whisper_kwargs = { - "model_size_or_path": model, - "device": whisper_device_type, - "compute_type": "int8", - "download_root": WHISPER_MODEL_DIR, - "local_files_only": not auto_update, - } - - try: - app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs) - except Exception: - log.warning( - "WhisperModel initialization failed, attempting download with local_files_only=False" - ) - faster_whisper_kwargs["local_files_only"] = False - app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs) - - else: - app.state.faster_whisper_model = None - - -class TTSConfigForm(BaseModel): - OPENAI_API_BASE_URL: str - OPENAI_API_KEY: str - API_KEY: str - ENGINE: str - MODEL: str - VOICE: str - SPLIT_ON: str - AZURE_SPEECH_REGION: str - AZURE_SPEECH_OUTPUT_FORMAT: str - - -class STTConfigForm(BaseModel): - OPENAI_API_BASE_URL: str - OPENAI_API_KEY: str - ENGINE: str - MODEL: str - WHISPER_MODEL: str - - -class AudioConfigUpdateForm(BaseModel): - tts: TTSConfigForm - stt: STTConfigForm - - -from pydub import AudioSegment -from pydub.utils import mediainfo - - -def is_mp4_audio(file_path): - """Check if the given file is an MP4 audio file.""" - if not os.path.isfile(file_path): - print(f"File not found: {file_path}") - return False - - info = mediainfo(file_path) - if ( - info.get("codec_name") == "aac" - and info.get("codec_type") == "audio" - and info.get("codec_tag_string") == "mp4a" - ): - return True - return False - - -def convert_mp4_to_wav(file_path, output_path): - """Convert MP4 audio file to WAV format.""" - audio = AudioSegment.from_file(file_path, format="mp4") - audio.export(output_path, format="wav") - print(f"Converted {file_path} to {output_path}") - - -@app.get("/config") -async def get_audio_config(user=Depends(get_admin_user)): - return { - "tts": { - "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, - "API_KEY": app.state.config.TTS_API_KEY, - "ENGINE": app.state.config.TTS_ENGINE, - "MODEL": app.state.config.TTS_MODEL, - "VOICE": app.state.config.TTS_VOICE, - "SPLIT_ON": app.state.config.TTS_SPLIT_ON, - "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION, - "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, - }, - "stt": { - "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, - "ENGINE": app.state.config.STT_ENGINE, - "MODEL": app.state.config.STT_MODEL, - "WHISPER_MODEL": app.state.config.WHISPER_MODEL, - }, - } - - -@app.post("/config/update") -async def update_audio_config( - form_data: AudioConfigUpdateForm, user=Depends(get_admin_user) -): - app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL - app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY - app.state.config.TTS_API_KEY = form_data.tts.API_KEY - app.state.config.TTS_ENGINE = form_data.tts.ENGINE - app.state.config.TTS_MODEL = form_data.tts.MODEL - app.state.config.TTS_VOICE = form_data.tts.VOICE - app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON - app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION - 
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = ( - form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT - ) - - app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL - app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY - app.state.config.STT_ENGINE = form_data.stt.ENGINE - app.state.config.STT_MODEL = form_data.stt.MODEL - app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL - set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE) - - return { - "tts": { - "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, - "API_KEY": app.state.config.TTS_API_KEY, - "ENGINE": app.state.config.TTS_ENGINE, - "MODEL": app.state.config.TTS_MODEL, - "VOICE": app.state.config.TTS_VOICE, - "SPLIT_ON": app.state.config.TTS_SPLIT_ON, - "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION, - "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, - }, - "stt": { - "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, - "ENGINE": app.state.config.STT_ENGINE, - "MODEL": app.state.config.STT_MODEL, - "WHISPER_MODEL": app.state.config.WHISPER_MODEL, - }, - } - - -def load_speech_pipeline(): - from transformers import pipeline - from datasets import load_dataset - - if app.state.speech_synthesiser is None: - app.state.speech_synthesiser = pipeline( - "text-to-speech", "microsoft/speecht5_tts" - ) - - if app.state.speech_speaker_embeddings_dataset is None: - app.state.speech_speaker_embeddings_dataset = load_dataset( - "Matthijs/cmu-arctic-xvectors", split="validation" - ) - - -@app.post("/speech") -async def speech(request: Request, user=Depends(get_verified_user)): - body = await request.body() - name = hashlib.sha256(body).hexdigest() - - file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3") - file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json") - - # Check if the file already exists in the cache - if file_path.is_file(): - return FileResponse(file_path) - - if app.state.config.TTS_ENGINE == "openai": - headers = {} - headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}" - headers["Content-Type"] = "application/json" - - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role - - try: - body = body.decode("utf-8") - body = json.loads(body) - body["model"] = app.state.config.TTS_MODEL - body = json.dumps(body).encode("utf-8") - except Exception: - pass - - try: - async with aiohttp.ClientSession() as session: - async with session.post( - url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", - data=body, - headers=headers, - ) as r: - r.raise_for_status() - async with aiofiles.open(file_path, "wb") as f: - await f.write(await r.read()) - - async with aiofiles.open(file_body_path, "w") as f: - await f.write(json.dumps(json.loads(body.decode("utf-8")))) - - return FileResponse(file_path) - - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - try: - if r.status != 200: - res = await r.json() - if "error" in res: - error_detail = f"External: {res['error']['message']}" - except Exception: - error_detail = f"External: {e}" - - raise HTTPException( - status_code=getattr(r, "status", 500), - detail=error_detail, - ) - - elif 
app.state.config.TTS_ENGINE == "elevenlabs": - try: - payload = json.loads(body.decode("utf-8")) - except Exception as e: - log.exception(e) - raise HTTPException(status_code=400, detail="Invalid JSON payload") - - voice_id = payload.get("voice", "") - if voice_id not in get_available_voices(): - raise HTTPException( - status_code=400, - detail="Invalid voice id", - ) - - url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}" - headers = { - "Accept": "audio/mpeg", - "Content-Type": "application/json", - "xi-api-key": app.state.config.TTS_API_KEY, - } - data = { - "text": payload["input"], - "model_id": app.state.config.TTS_MODEL, - "voice_settings": {"stability": 0.5, "similarity_boost": 0.5}, - } - - try: - async with aiohttp.ClientSession() as session: - async with session.post(url, json=data, headers=headers) as r: - r.raise_for_status() - async with aiofiles.open(file_path, "wb") as f: - await f.write(await r.read()) - - async with aiofiles.open(file_body_path, "w") as f: - await f.write(json.dumps(json.loads(body.decode("utf-8")))) - - return FileResponse(file_path) - - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - try: - if r.status != 200: - res = await r.json() - if "error" in res: - error_detail = f"External: {res['error']['message']}" - except Exception: - error_detail = f"External: {e}" - - raise HTTPException( - status_code=getattr(r, "status", 500), - detail=error_detail, - ) - - elif app.state.config.TTS_ENGINE == "azure": - try: - payload = json.loads(body.decode("utf-8")) - except Exception as e: - log.exception(e) - raise HTTPException(status_code=400, detail="Invalid JSON payload") - - region = app.state.config.TTS_AZURE_SPEECH_REGION - language = app.state.config.TTS_VOICE - locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1]) - output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT - url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1" - - headers = { - "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY, - "Content-Type": "application/ssml+xml", - "X-Microsoft-OutputFormat": output_format, - } - - data = f""" - {payload["input"]} - """ - - try: - async with aiohttp.ClientSession() as session: - async with session.post(url, headers=headers, data=data) as response: - if response.status == 200: - async with aiofiles.open(file_path, "wb") as f: - await f.write(await response.read()) - return FileResponse(file_path) - else: - error_msg = f"Error synthesizing speech - {response.reason}" - log.error(error_msg) - raise HTTPException(status_code=500, detail=error_msg) - except Exception as e: - log.exception(e) - raise HTTPException(status_code=500, detail=str(e)) - elif app.state.config.TTS_ENGINE == "transformers": - payload = None - try: - payload = json.loads(body.decode("utf-8")) - except Exception as e: - log.exception(e) - raise HTTPException(status_code=400, detail="Invalid JSON payload") - - import torch - import soundfile as sf - - load_speech_pipeline() - - embeddings_dataset = app.state.speech_speaker_embeddings_dataset - - speaker_index = 6799 - try: - speaker_index = embeddings_dataset["filename"].index( - app.state.config.TTS_MODEL - ) - except Exception: - pass - - speaker_embedding = torch.tensor( - embeddings_dataset[speaker_index]["xvector"] - ).unsqueeze(0) - - speech = app.state.speech_synthesiser( - payload["input"], - forward_params={"speaker_embeddings": speaker_embedding}, - ) - - sf.write(file_path, speech["audio"], 
samplerate=speech["sampling_rate"]) - with open(file_body_path, "w") as f: - json.dump(json.loads(body.decode("utf-8")), f) - - return FileResponse(file_path) - - -def transcribe(file_path): - print("transcribe", file_path) - filename = os.path.basename(file_path) - file_dir = os.path.dirname(file_path) - id = filename.split(".")[0] - - if app.state.config.STT_ENGINE == "": - if app.state.faster_whisper_model is None: - set_faster_whisper_model(app.state.config.WHISPER_MODEL) - - model = app.state.faster_whisper_model - segments, info = model.transcribe(file_path, beam_size=5) - log.info( - "Detected language '%s' with probability %f" - % (info.language, info.language_probability) - ) - - transcript = "".join([segment.text for segment in list(segments)]) - data = {"text": transcript.strip()} - - # save the transcript to a json file - transcript_file = f"{file_dir}/{id}.json" - with open(transcript_file, "w") as f: - json.dump(data, f) - - log.debug(data) - return data - elif app.state.config.STT_ENGINE == "openai": - if is_mp4_audio(file_path): - print("is_mp4_audio") - os.rename(file_path, file_path.replace(".wav", ".mp4")) - # Convert MP4 audio file to WAV format - convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path) - - headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"} - - files = {"file": (filename, open(file_path, "rb"))} - data = {"model": app.state.config.STT_MODEL} - - log.debug(files, data) - - r = None - try: - r = requests.post( - url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", - headers=headers, - files=files, - data=data, - ) - - r.raise_for_status() - - data = r.json() - - # save the transcript to a json file - transcript_file = f"{file_dir}/{id}.json" - with open(transcript_file, "w") as f: - json.dump(data, f) - - print(data) - return data - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"External: {res['error']['message']}" - except Exception: - error_detail = f"External: {e}" - - raise Exception(error_detail) - - -@app.post("/transcriptions") -def transcription( - file: UploadFile = File(...), - user=Depends(get_verified_user), -): - log.info(f"file.content_type: {file.content_type}") - - if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED, - ) - - try: - ext = file.filename.split(".")[-1] - id = uuid.uuid4() - - filename = f"{id}.{ext}" - contents = file.file.read() - - file_dir = f"{CACHE_DIR}/audio/transcriptions" - os.makedirs(file_dir, exist_ok=True) - file_path = f"{file_dir}/{filename}" - - with open(file_path, "wb") as f: - f.write(contents) - - try: - if os.path.getsize(file_path) > MAX_FILE_SIZE: # file is bigger than 25MB - log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB") - audio = AudioSegment.from_file(file_path) - audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio - compressed_path = f"{file_dir}/{id}_compressed.opus" - audio.export(compressed_path, format="opus", bitrate="32k") - log.debug(f"Compressed audio to {compressed_path}") - file_path = compressed_path - - if ( - os.path.getsize(file_path) > MAX_FILE_SIZE - ): # Still larger than 25MB after compression - log.debug( - f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}" - ) - raise HTTPException( - 
status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.FILE_TOO_LARGE( - size=f"{MAX_FILE_SIZE_MB}MB" - ), - ) - - data = transcribe(file_path) - else: - data = transcribe(file_path) - - file_path = file_path.split("/")[-1] - return {**data, "filename": file_path} - except Exception as e: - log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - except Exception as e: - log.exception(e) - - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - -def get_available_models() -> list[dict]: - if app.state.config.TTS_ENGINE == "openai": - return [{"id": "tts-1"}, {"id": "tts-1-hd"}] - elif app.state.config.TTS_ENGINE == "elevenlabs": - headers = { - "xi-api-key": app.state.config.TTS_API_KEY, - "Content-Type": "application/json", - } - - try: - response = requests.get( - "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5 - ) - response.raise_for_status() - models = response.json() - return [ - {"name": model["name"], "id": model["model_id"]} for model in models - ] - except requests.RequestException as e: - log.error(f"Error fetching voices: {str(e)}") - return [] - - -@app.get("/models") -async def get_models(user=Depends(get_verified_user)): - return {"models": get_available_models()} - - -def get_available_voices() -> dict: - """Returns {voice_id: voice_name} dict""" - ret = {} - if app.state.config.TTS_ENGINE == "openai": - ret = { - "alloy": "alloy", - "echo": "echo", - "fable": "fable", - "onyx": "onyx", - "nova": "nova", - "shimmer": "shimmer", - } - elif app.state.config.TTS_ENGINE == "elevenlabs": - try: - ret = get_elevenlabs_voices() - except Exception: - # Avoided @lru_cache with exception - pass - elif app.state.config.TTS_ENGINE == "azure": - try: - region = app.state.config.TTS_AZURE_SPEECH_REGION - url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list" - headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY} - - response = requests.get(url, headers=headers) - response.raise_for_status() - voices = response.json() - for voice in voices: - ret[voice["ShortName"]] = ( - f"{voice['DisplayName']} ({voice['ShortName']})" - ) - except requests.RequestException as e: - log.error(f"Error fetching voices: {str(e)}") - - return ret - - -@lru_cache -def get_elevenlabs_voices() -> dict: - """ - Note, set the following in your .env file to use Elevenlabs: - AUDIO_TTS_ENGINE=elevenlabs - AUDIO_TTS_API_KEY=sk_... 
# Your Elevenlabs API key - AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices - AUDIO_TTS_MODEL=eleven_multilingual_v2 - """ - headers = { - "xi-api-key": app.state.config.TTS_API_KEY, - "Content-Type": "application/json", - } - try: - # TODO: Add retries - response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers) - response.raise_for_status() - voices_data = response.json() - - voices = {} - for voice in voices_data.get("voices", []): - voices[voice["voice_id"]] = voice["name"] - except requests.RequestException as e: - # Avoid @lru_cache with exception - log.error(f"Error fetching voices: {str(e)}") - raise RuntimeError(f"Error fetching voices: {str(e)}") - - return voices - - -@app.get("/voices") -async def get_voices(user=Depends(get_verified_user)): - return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]} diff --git a/backend/open_webui/apps/retrieval/vector/connector.py b/backend/open_webui/apps/retrieval/vector/connector.py deleted file mode 100644 index 528835b56..000000000 --- a/backend/open_webui/apps/retrieval/vector/connector.py +++ /dev/null @@ -1,22 +0,0 @@ -from open_webui.config import VECTOR_DB - -if VECTOR_DB == "milvus": - from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient - - VECTOR_DB_CLIENT = MilvusClient() -elif VECTOR_DB == "qdrant": - from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient - - VECTOR_DB_CLIENT = QdrantClient() -elif VECTOR_DB == "opensearch": - from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient - - VECTOR_DB_CLIENT = OpenSearchClient() -elif VECTOR_DB == "pgvector": - from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient - - VECTOR_DB_CLIENT = PgvectorClient() -else: - from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient - - VECTOR_DB_CLIENT = ChromaClient() diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py deleted file mode 100644 index 054c6280e..000000000 --- a/backend/open_webui/apps/webui/main.py +++ /dev/null @@ -1,506 +0,0 @@ -import inspect -import json -import logging -import time -from typing import AsyncGenerator, Generator, Iterator - -from open_webui.apps.socket.main import get_event_call, get_event_emitter -from open_webui.apps.webui.models.functions import Functions -from open_webui.apps.webui.models.models import Models -from open_webui.apps.webui.routers import ( - auths, - chats, - folders, - configs, - groups, - files, - functions, - memories, - models, - knowledge, - prompts, - evaluations, - tools, - users, - utils, -) -from open_webui.apps.webui.utils import load_function_module_by_id -from open_webui.config import ( - ADMIN_EMAIL, - CORS_ALLOW_ORIGIN, - DEFAULT_MODELS, - DEFAULT_PROMPT_SUGGESTIONS, - DEFAULT_USER_ROLE, - MODEL_ORDER_LIST, - ENABLE_COMMUNITY_SHARING, - ENABLE_LOGIN_FORM, - ENABLE_MESSAGE_RATING, - ENABLE_SIGNUP, - ENABLE_API_KEY, - ENABLE_EVALUATION_ARENA_MODELS, - EVALUATION_ARENA_MODELS, - DEFAULT_ARENA_MODEL, - JWT_EXPIRES_IN, - ENABLE_OAUTH_ROLE_MANAGEMENT, - OAUTH_ROLES_CLAIM, - OAUTH_EMAIL_CLAIM, - OAUTH_PICTURE_CLAIM, - OAUTH_USERNAME_CLAIM, - OAUTH_ALLOWED_ROLES, - OAUTH_ADMIN_ROLES, - SHOW_ADMIN_DETAILS, - USER_PERMISSIONS, - WEBHOOK_URL, - WEBUI_AUTH, - WEBUI_BANNERS, - ENABLE_LDAP, - LDAP_SERVER_LABEL, - LDAP_SERVER_HOST, - LDAP_SERVER_PORT, - LDAP_ATTRIBUTE_FOR_USERNAME, - LDAP_SEARCH_FILTERS, - LDAP_SEARCH_BASE, - LDAP_APP_DN, - LDAP_APP_PASSWORD, - LDAP_USE_TLS, - LDAP_CA_CERT_FILE, - 
LDAP_CIPHERS, - AppConfig, -) -from open_webui.env import ( - ENV, - SRC_LOG_LEVELS, - WEBUI_AUTH_TRUSTED_EMAIL_HEADER, - WEBUI_AUTH_TRUSTED_NAME_HEADER, -) -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse -from pydantic import BaseModel -from open_webui.utils.misc import ( - openai_chat_chunk_message_template, - openai_chat_completion_message_template, -) -from open_webui.utils.payload import ( - apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, -) - - -from open_webui.utils.tools import get_tools - -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) - -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["MAIN"]) - -app.state.config = AppConfig() - -app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP -app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM -app.state.config.ENABLE_API_KEY = ENABLE_API_KEY - -app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN -app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER -app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER - - -app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS -app.state.config.ADMIN_EMAIL = ADMIN_EMAIL - - -app.state.config.DEFAULT_MODELS = DEFAULT_MODELS -app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS -app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE - - -app.state.config.USER_PERMISSIONS = USER_PERMISSIONS -app.state.config.WEBHOOK_URL = WEBHOOK_URL -app.state.config.BANNERS = WEBUI_BANNERS -app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST - -app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING -app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING - -app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS -app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS - -app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM -app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM -app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM - -app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT -app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM -app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES -app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES - -app.state.config.ENABLE_LDAP = ENABLE_LDAP -app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL -app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST -app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT -app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME -app.state.config.LDAP_APP_DN = LDAP_APP_DN -app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD -app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE -app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS -app.state.config.LDAP_USE_TLS = LDAP_USE_TLS -app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE -app.state.config.LDAP_CIPHERS = LDAP_CIPHERS - -app.state.TOOLS = {} -app.state.FUNCTIONS = {} - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -app.include_router(configs.router, prefix="/configs", tags=["configs"]) - -app.include_router(auths.router, prefix="/auths", tags=["auths"]) -app.include_router(users.router, prefix="/users", tags=["users"]) - -app.include_router(chats.router, prefix="/chats", tags=["chats"]) - 
-app.include_router(models.router, prefix="/models", tags=["models"]) -app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"]) -app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) -app.include_router(tools.router, prefix="/tools", tags=["tools"]) - -app.include_router(memories.router, prefix="/memories", tags=["memories"]) -app.include_router(folders.router, prefix="/folders", tags=["folders"]) - -app.include_router(groups.router, prefix="/groups", tags=["groups"]) -app.include_router(files.router, prefix="/files", tags=["files"]) -app.include_router(functions.router, prefix="/functions", tags=["functions"]) -app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"]) - - -app.include_router(utils.router, prefix="/utils", tags=["utils"]) - - -@app.get("/") -async def get_status(): - return { - "status": True, - "auth": WEBUI_AUTH, - "default_models": app.state.config.DEFAULT_MODELS, - "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, - } - - -async def get_all_models(): - models = [] - pipe_models = await get_pipe_models() - models = models + pipe_models - - if app.state.config.ENABLE_EVALUATION_ARENA_MODELS: - arena_models = [] - if len(app.state.config.EVALUATION_ARENA_MODELS) > 0: - arena_models = [ - { - "id": model["id"], - "name": model["name"], - "info": { - "meta": model["meta"], - }, - "object": "model", - "created": int(time.time()), - "owned_by": "arena", - "arena": True, - } - for model in app.state.config.EVALUATION_ARENA_MODELS - ] - else: - # Add default arena model - arena_models = [ - { - "id": DEFAULT_ARENA_MODEL["id"], - "name": DEFAULT_ARENA_MODEL["name"], - "info": { - "meta": DEFAULT_ARENA_MODEL["meta"], - }, - "object": "model", - "created": int(time.time()), - "owned_by": "arena", - "arena": True, - } - ] - models = models + arena_models - return models - - -def get_function_module(pipe_id: str): - # Check if function is already loaded - if pipe_id not in app.state.FUNCTIONS: - function_module, _, _ = load_function_module_by_id(pipe_id) - app.state.FUNCTIONS[pipe_id] = function_module - else: - function_module = app.state.FUNCTIONS[pipe_id] - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(pipe_id) - function_module.valves = function_module.Valves(**(valves if valves else {})) - return function_module - - -async def get_pipe_models(): - pipes = Functions.get_functions_by_type("pipe", active_only=True) - pipe_models = [] - - for pipe in pipes: - function_module = get_function_module(pipe.id) - - # Check if function is a manifold - if hasattr(function_module, "pipes"): - sub_pipes = [] - - # Check if pipes is a function or a list - - try: - if callable(function_module.pipes): - sub_pipes = function_module.pipes() - else: - sub_pipes = function_module.pipes - except Exception as e: - log.exception(e) - sub_pipes = [] - - log.debug( - f"get_pipe_models: function '{pipe.id}' is a manifold of {sub_pipes}" - ) - - for p in sub_pipes: - sub_pipe_id = f'{pipe.id}.{p["id"]}' - sub_pipe_name = p["name"] - - if hasattr(function_module, "name"): - sub_pipe_name = f"{function_module.name}{sub_pipe_name}" - - pipe_flag = {"type": pipe.type} - - pipe_models.append( - { - "id": sub_pipe_id, - "name": sub_pipe_name, - "object": "model", - "created": pipe.created_at, - "owned_by": "openai", - "pipe": pipe_flag, - } - ) - else: - pipe_flag = {"type": "pipe"} - - log.debug( - f"get_pipe_models: function '{pipe.id}' is a 
single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}" - ) - - pipe_models.append( - { - "id": pipe.id, - "name": pipe.name, - "object": "model", - "created": pipe.created_at, - "owned_by": "openai", - "pipe": pipe_flag, - } - ) - - return pipe_models - - -async def execute_pipe(pipe, params): - if inspect.iscoroutinefunction(pipe): - return await pipe(**params) - else: - return pipe(**params) - - -async def get_message_content(res: str | Generator | AsyncGenerator) -> str: - if isinstance(res, str): - return res - if isinstance(res, Generator): - return "".join(map(str, res)) - if isinstance(res, AsyncGenerator): - return "".join([str(stream) async for stream in res]) - - -def process_line(form_data: dict, line): - if isinstance(line, BaseModel): - line = line.model_dump_json() - line = f"data: {line}" - if isinstance(line, dict): - line = f"data: {json.dumps(line)}" - - try: - line = line.decode("utf-8") - except Exception: - pass - - if line.startswith("data:"): - return f"{line}\n\n" - else: - line = openai_chat_chunk_message_template(form_data["model"], line) - return f"data: {json.dumps(line)}\n\n" - - -def get_pipe_id(form_data: dict) -> str: - pipe_id = form_data["model"] - if "." in pipe_id: - pipe_id, _ = pipe_id.split(".", 1) - - return pipe_id - - -def get_function_params(function_module, form_data, user, extra_params=None): - if extra_params is None: - extra_params = {} - - pipe_id = get_pipe_id(form_data) - - # Get the signature of the function - sig = inspect.signature(function_module.pipe) - params = {"body": form_data} | { - k: v for k, v in extra_params.items() if k in sig.parameters - } - - if "__user__" in params and hasattr(function_module, "UserValves"): - user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) - try: - params["__user__"]["valves"] = function_module.UserValves(**user_valves) - except Exception as e: - log.exception(e) - params["__user__"]["valves"] = function_module.UserValves() - - return params - - -async def generate_function_chat_completion(form_data, user, models: dict = {}): - model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id) - - metadata = form_data.pop("metadata", {}) - - files = metadata.get("files", []) - tool_ids = metadata.get("tool_ids", []) - # Check if tool_ids is None - if tool_ids is None: - tool_ids = [] - - __event_emitter__ = None - __event_call__ = None - __task__ = None - __task_body__ = None - - if metadata: - if all(k in metadata for k in ("session_id", "chat_id", "message_id")): - __event_emitter__ = get_event_emitter(metadata) - __event_call__ = get_event_call(metadata) - __task__ = metadata.get("task", None) - __task_body__ = metadata.get("task_body", None) - - extra_params = { - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__task__": __task__, - "__task_body__": __task_body__, - "__files__": files, - "__user__": { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - "__metadata__": metadata, - } - extra_params["__tools__"] = get_tools( - app, - tool_ids, - user, - { - **extra_params, - "__model__": models.get(form_data["model"], None), - "__messages__": form_data["messages"], - "__files__": files, - }, - ) - - if model_info: - if model_info.base_model_id: - form_data["model"] = model_info.base_model_id - - params = model_info.params.model_dump() - form_data = apply_model_params_to_body_openai(params, form_data) - form_data = apply_model_system_prompt_to_body(params, form_data, user) - - pipe_id = 
get_pipe_id(form_data) - function_module = get_function_module(pipe_id) - - pipe = function_module.pipe - params = get_function_params(function_module, form_data, user, extra_params) - - if form_data.get("stream", False): - - async def stream_content(): - try: - res = await execute_pipe(pipe, params) - - # Directly return if the response is a StreamingResponse - if isinstance(res, StreamingResponse): - async for data in res.body_iterator: - yield data - return - if isinstance(res, dict): - yield f"data: {json.dumps(res)}\n\n" - return - - except Exception as e: - log.error(f"Error: {e}") - yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" - return - - if isinstance(res, str): - message = openai_chat_chunk_message_template(form_data["model"], res) - yield f"data: {json.dumps(message)}\n\n" - - if isinstance(res, Iterator): - for line in res: - yield process_line(form_data, line) - - if isinstance(res, AsyncGenerator): - async for line in res: - yield process_line(form_data, line) - - if isinstance(res, str) or isinstance(res, Generator): - finish_message = openai_chat_chunk_message_template( - form_data["model"], "" - ) - finish_message["choices"][0]["finish_reason"] = "stop" - yield f"data: {json.dumps(finish_message)}\n\n" - yield "data: [DONE]" - - return StreamingResponse(stream_content(), media_type="text/event-stream") - else: - try: - res = await execute_pipe(pipe, params) - - except Exception as e: - log.error(f"Error: {e}") - return {"error": {"detail": str(e)}} - - if isinstance(res, StreamingResponse) or isinstance(res, dict): - return res - if isinstance(res, BaseModel): - return res.model_dump() - - message = await get_message_content(res) - return openai_chat_completion_message_template(form_data["model"], message) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index c0a0f63b5..e49c251a1 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -10,7 +10,7 @@ from urllib.parse import urlparse import chromadb import requests import yaml -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from open_webui.env import ( OPEN_WEBUI_DIR, DATA_DIR, @@ -21,6 +21,7 @@ from open_webui.env import ( WEBUI_NAME, log, DATABASE_URL, + OFFLINE_MODE ) from pydantic import BaseModel from sqlalchemy import JSON, Column, DateTime, Integer, func @@ -429,6 +430,15 @@ OAUTH_ADMIN_ROLES = PersistentConfig( [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")], ) +OAUTH_ALLOWED_DOMAINS = PersistentConfig( + "OAUTH_ALLOWED_DOMAINS", + "oauth.allowed_domains", + [ + domain.strip() + for domain in os.environ.get("OAUTH_ALLOWED_DOMAINS", "*").split(",") + ], +) + def load_oauth_providers(): OAUTH_PROVIDERS.clear() @@ -948,12 +958,45 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""), ) +DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. 
+ +Examples of titles: +📉 Stock Market Trends +🍪 Perfect Chocolate Chip Recipe +Evolution of Music Streaming +Remote Work Productivity Tips +Artificial Intelligence in Healthcare +🎮 Video Game Development Insights + + +{{MESSAGES:END:2}} +""" + + TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig( "TAGS_GENERATION_PROMPT_TEMPLATE", "task.tags.prompt_template", os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""), ) +DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE = """### Task: +Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags. + +### Guidelines: +- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education) +- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation +- If content is too short (less than 3 messages) or too diverse, use only ["General"] +- Use the chat's primary language; default to English if multilingual +- Prioritize accuracy over specificity + +### Output: +JSON format: { "tags": ["tag1", "tag2", "tag3"] } + +### Chat History: + +{{MESSAGES:END:6}} +""" + ENABLE_TAGS_GENERATION = PersistentConfig( "ENABLE_TAGS_GENERATION", "task.tags.enable", @@ -1072,6 +1115,19 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( ) +DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text.""" + + +DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). + +Message: ```{{prompt}}```""" + +DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" + +Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. 
+ +Responses from models: {{responses}}""" + #################################### # Vector Database #################################### @@ -1197,7 +1253,7 @@ RAG_EMBEDDING_MODEL = PersistentConfig( log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}") RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( - os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true" + not OFFLINE_MODE and os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true" ) RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( @@ -1222,7 +1278,7 @@ if RAG_RERANKING_MODEL.value != "": log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}") RAG_RERANKING_MODEL_AUTO_UPDATE = ( - os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true" + not OFFLINE_MODE and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true" ) RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( @@ -1380,6 +1436,12 @@ BRAVE_SEARCH_API_KEY = PersistentConfig( os.getenv("BRAVE_SEARCH_API_KEY", ""), ) +KAGI_SEARCH_API_KEY = PersistentConfig( + "KAGI_SEARCH_API_KEY", + "rag.web.search.kagi_search_api_key", + os.getenv("KAGI_SEARCH_API_KEY", ""), +) + MOJEEK_SEARCH_API_KEY = PersistentConfig( "MOJEEK_SEARCH_API_KEY", "rag.web.search.mojeek_search_api_key", @@ -1686,7 +1748,7 @@ WHISPER_MODEL = PersistentConfig( WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models") WHISPER_MODEL_AUTO_UPDATE = ( - os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" + not OFFLINE_MODE and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" ) diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index ffdc72d57..cd08cffed 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -378,7 +378,7 @@ else: AIOHTTP_CLIENT_TIMEOUT = 300 AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get( - "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "5" + "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "" ) if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "": diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py new file mode 100644 index 000000000..16536a612 --- /dev/null +++ b/backend/open_webui/functions.py @@ -0,0 +1,316 @@ +import logging +import sys +import inspect +import json + +from pydantic import BaseModel +from typing import AsyncGenerator, Generator, Iterator +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + Request, + UploadFile, + status, +) +from starlette.responses import Response, StreamingResponse + + +from open_webui.socket.main import ( + get_event_call, + get_event_emitter, +) + + +from open_webui.models.functions import Functions +from open_webui.models.models import Models + +from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.tools import get_tools +from open_webui.utils.access_control import has_access + +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL + +from open_webui.utils.misc import ( + add_or_update_system_message, + get_last_user_message, + prepend_to_first_user_message_content, + openai_chat_chunk_message_template, + openai_chat_completion_message_template, +) +from open_webui.utils.payload import ( + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, +) + + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +def get_function_module_by_id(request: Request, pipe_id: str): + # Check if function is already loaded + if pipe_id 
not in request.app.state.FUNCTIONS: + function_module, _, _ = load_function_module_by_id(pipe_id) + request.app.state.FUNCTIONS[pipe_id] = function_module + else: + function_module = request.app.state.FUNCTIONS[pipe_id] + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(pipe_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + return function_module + + +async def get_function_models(request): + pipes = Functions.get_functions_by_type("pipe", active_only=True) + pipe_models = [] + + for pipe in pipes: + function_module = get_function_module_by_id(request, pipe.id) + + # Check if function is a manifold + if hasattr(function_module, "pipes"): + sub_pipes = [] + + # Check if pipes is a function or a list + + try: + if callable(function_module.pipes): + sub_pipes = function_module.pipes() + else: + sub_pipes = function_module.pipes + except Exception as e: + log.exception(e) + sub_pipes = [] + + log.debug( + f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}" + ) + + for p in sub_pipes: + sub_pipe_id = f'{pipe.id}.{p["id"]}' + sub_pipe_name = p["name"] + + if hasattr(function_module, "name"): + sub_pipe_name = f"{function_module.name}{sub_pipe_name}" + + pipe_flag = {"type": pipe.type} + + pipe_models.append( + { + "id": sub_pipe_id, + "name": sub_pipe_name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": pipe_flag, + } + ) + else: + pipe_flag = {"type": "pipe"} + + log.debug( + f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}" + ) + + pipe_models.append( + { + "id": pipe.id, + "name": pipe.name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": pipe_flag, + } + ) + + return pipe_models + + +async def generate_function_chat_completion( + request, form_data, user, models: dict = {} +): + async def execute_pipe(pipe, params): + if inspect.iscoroutinefunction(pipe): + return await pipe(**params) + else: + return pipe(**params) + + async def get_message_content(res: str | Generator | AsyncGenerator) -> str: + if isinstance(res, str): + return res + if isinstance(res, Generator): + return "".join(map(str, res)) + if isinstance(res, AsyncGenerator): + return "".join([str(stream) async for stream in res]) + + def process_line(form_data: dict, line): + if isinstance(line, BaseModel): + line = line.model_dump_json() + line = f"data: {line}" + if isinstance(line, dict): + line = f"data: {json.dumps(line)}" + + try: + line = line.decode("utf-8") + except Exception: + pass + + if line.startswith("data:"): + return f"{line}\n\n" + else: + line = openai_chat_chunk_message_template(form_data["model"], line) + return f"data: {json.dumps(line)}\n\n" + + def get_pipe_id(form_data: dict) -> str: + pipe_id = form_data["model"] + if "." 
in pipe_id: + pipe_id, _ = pipe_id.split(".", 1) + return pipe_id + + def get_function_params(function_module, form_data, user, extra_params=None): + if extra_params is None: + extra_params = {} + + pipe_id = get_pipe_id(form_data) + + # Get the signature of the function + sig = inspect.signature(function_module.pipe) + params = {"body": form_data} | { + k: v for k, v in extra_params.items() if k in sig.parameters + } + + if "__user__" in params and hasattr(function_module, "UserValves"): + user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) + try: + params["__user__"]["valves"] = function_module.UserValves(**user_valves) + except Exception as e: + log.exception(e) + params["__user__"]["valves"] = function_module.UserValves() + + return params + + model_id = form_data.get("model") + model_info = Models.get_model_by_id(model_id) + + metadata = form_data.pop("metadata", {}) + + files = metadata.get("files", []) + tool_ids = metadata.get("tool_ids", []) + # Check if tool_ids is None + if tool_ids is None: + tool_ids = [] + + __event_emitter__ = None + __event_call__ = None + __task__ = None + __task_body__ = None + + if metadata: + if all(k in metadata for k in ("session_id", "chat_id", "message_id")): + __event_emitter__ = get_event_emitter(metadata) + __event_call__ = get_event_call(metadata) + __task__ = metadata.get("task", None) + __task_body__ = metadata.get("task_body", None) + + extra_params = { + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__task__": __task__, + "__task_body__": __task_body__, + "__files__": files, + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + "__metadata__": metadata, + "__request__": request, + } + extra_params["__tools__"] = get_tools( + request, + tool_ids, + user, + { + **extra_params, + "__model__": models.get(form_data["model"], None), + "__messages__": form_data["messages"], + "__files__": files, + }, + ) + + if model_info: + if model_info.base_model_id: + form_data["model"] = model_info.base_model_id + + params = model_info.params.model_dump() + form_data = apply_model_params_to_body_openai(params, form_data) + form_data = apply_model_system_prompt_to_body(params, form_data, user) + + pipe_id = get_pipe_id(form_data) + function_module = get_function_module_by_id(request, pipe_id) + + pipe = function_module.pipe + params = get_function_params(function_module, form_data, user, extra_params) + + if form_data.get("stream", False): + + async def stream_content(): + try: + res = await execute_pipe(pipe, params) + + # Directly return if the response is a StreamingResponse + if isinstance(res, StreamingResponse): + async for data in res.body_iterator: + yield data + return + if isinstance(res, dict): + yield f"data: {json.dumps(res)}\n\n" + return + + except Exception as e: + log.error(f"Error: {e}") + yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" + return + + if isinstance(res, str): + message = openai_chat_chunk_message_template(form_data["model"], res) + yield f"data: {json.dumps(message)}\n\n" + + if isinstance(res, Iterator): + for line in res: + yield process_line(form_data, line) + + if isinstance(res, AsyncGenerator): + async for line in res: + yield process_line(form_data, line) + + if isinstance(res, str) or isinstance(res, Generator): + finish_message = openai_chat_chunk_message_template( + form_data["model"], "" + ) + finish_message["choices"][0]["finish_reason"] = "stop" + yield f"data: {json.dumps(finish_message)}\n\n" + 
yield "data: [DONE]" + + return StreamingResponse(stream_content(), media_type="text/event-stream") + else: + try: + res = await execute_pipe(pipe, params) + + except Exception as e: + log.error(f"Error: {e}") + return {"error": {"detail": str(e)}} + + if isinstance(res, StreamingResponse) or isinstance(res, dict): + return res + if isinstance(res, BaseModel): + return res.model_dump() + + message = await get_message_content(res) + return openai_chat_completion_message_template(form_data["model"], message) diff --git a/backend/open_webui/apps/webui/internal/db.py b/backend/open_webui/internal/db.py similarity index 97% rename from backend/open_webui/apps/webui/internal/db.py rename to backend/open_webui/internal/db.py index 72185ea1e..4452a6f23 100644 --- a/backend/open_webui/apps/webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -3,7 +3,7 @@ import logging from contextlib import contextmanager from typing import Any, Optional -from open_webui.apps.webui.internal.wrappers import register_connection +from open_webui.internal.wrappers import register_connection from open_webui.env import ( OPEN_WEBUI_DIR, DATABASE_URL, diff --git a/backend/open_webui/apps/webui/internal/migrations/001_initial_schema.py b/backend/open_webui/internal/migrations/001_initial_schema.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/001_initial_schema.py rename to backend/open_webui/internal/migrations/001_initial_schema.py diff --git a/backend/open_webui/apps/webui/internal/migrations/002_add_local_sharing.py b/backend/open_webui/internal/migrations/002_add_local_sharing.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/002_add_local_sharing.py rename to backend/open_webui/internal/migrations/002_add_local_sharing.py diff --git a/backend/open_webui/apps/webui/internal/migrations/003_add_auth_api_key.py b/backend/open_webui/internal/migrations/003_add_auth_api_key.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/003_add_auth_api_key.py rename to backend/open_webui/internal/migrations/003_add_auth_api_key.py diff --git a/backend/open_webui/apps/webui/internal/migrations/004_add_archived.py b/backend/open_webui/internal/migrations/004_add_archived.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/004_add_archived.py rename to backend/open_webui/internal/migrations/004_add_archived.py diff --git a/backend/open_webui/apps/webui/internal/migrations/005_add_updated_at.py b/backend/open_webui/internal/migrations/005_add_updated_at.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/005_add_updated_at.py rename to backend/open_webui/internal/migrations/005_add_updated_at.py diff --git a/backend/open_webui/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py b/backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py rename to backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py diff --git a/backend/open_webui/apps/webui/internal/migrations/007_add_user_last_active_at.py b/backend/open_webui/internal/migrations/007_add_user_last_active_at.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/007_add_user_last_active_at.py rename to backend/open_webui/internal/migrations/007_add_user_last_active_at.py diff 
--git a/backend/open_webui/apps/webui/internal/migrations/008_add_memory.py b/backend/open_webui/internal/migrations/008_add_memory.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/008_add_memory.py rename to backend/open_webui/internal/migrations/008_add_memory.py diff --git a/backend/open_webui/apps/webui/internal/migrations/009_add_models.py b/backend/open_webui/internal/migrations/009_add_models.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/009_add_models.py rename to backend/open_webui/internal/migrations/009_add_models.py diff --git a/backend/open_webui/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py b/backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py rename to backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py diff --git a/backend/open_webui/apps/webui/internal/migrations/011_add_user_settings.py b/backend/open_webui/internal/migrations/011_add_user_settings.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/011_add_user_settings.py rename to backend/open_webui/internal/migrations/011_add_user_settings.py diff --git a/backend/open_webui/apps/webui/internal/migrations/012_add_tools.py b/backend/open_webui/internal/migrations/012_add_tools.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/012_add_tools.py rename to backend/open_webui/internal/migrations/012_add_tools.py diff --git a/backend/open_webui/apps/webui/internal/migrations/013_add_user_info.py b/backend/open_webui/internal/migrations/013_add_user_info.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/013_add_user_info.py rename to backend/open_webui/internal/migrations/013_add_user_info.py diff --git a/backend/open_webui/apps/webui/internal/migrations/014_add_files.py b/backend/open_webui/internal/migrations/014_add_files.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/014_add_files.py rename to backend/open_webui/internal/migrations/014_add_files.py diff --git a/backend/open_webui/apps/webui/internal/migrations/015_add_functions.py b/backend/open_webui/internal/migrations/015_add_functions.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/015_add_functions.py rename to backend/open_webui/internal/migrations/015_add_functions.py diff --git a/backend/open_webui/apps/webui/internal/migrations/016_add_valves_and_is_active.py b/backend/open_webui/internal/migrations/016_add_valves_and_is_active.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/016_add_valves_and_is_active.py rename to backend/open_webui/internal/migrations/016_add_valves_and_is_active.py diff --git a/backend/open_webui/apps/webui/internal/migrations/017_add_user_oauth_sub.py b/backend/open_webui/internal/migrations/017_add_user_oauth_sub.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/017_add_user_oauth_sub.py rename to backend/open_webui/internal/migrations/017_add_user_oauth_sub.py diff --git a/backend/open_webui/apps/webui/internal/migrations/018_add_function_is_global.py b/backend/open_webui/internal/migrations/018_add_function_is_global.py similarity index 100% rename from 
backend/open_webui/apps/webui/internal/migrations/018_add_function_is_global.py rename to backend/open_webui/internal/migrations/018_add_function_is_global.py diff --git a/backend/open_webui/apps/webui/internal/wrappers.py b/backend/open_webui/internal/wrappers.py similarity index 100% rename from backend/open_webui/apps/webui/internal/wrappers.py rename to backend/open_webui/internal/wrappers.py diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 1bf221beb..31604984f 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -8,9 +8,13 @@ import shutil import sys import time import random -from contextlib import asynccontextmanager -from typing import Optional +from contextlib import asynccontextmanager +from urllib.parse import urlencode, parse_qs, urlparse +from pydantic import BaseModel +from sqlalchemy import text + +from typing import Optional from aiocache import cached import aiohttp import requests @@ -27,126 +31,260 @@ from fastapi import ( from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel -from sqlalchemy import text + from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware from starlette.responses import Response, StreamingResponse -from open_webui.apps.audio.main import app as audio_app -from open_webui.apps.images.main import app as images_app -from open_webui.apps.ollama.main import ( - app as ollama_app, - get_all_models as get_ollama_models, - generate_chat_completion as generate_ollama_chat_completion, - GenerateChatCompletionForm, -) -from open_webui.apps.openai.main import ( - app as openai_app, - generate_chat_completion as generate_openai_chat_completion, - get_all_models as get_openai_models, - get_all_models_responses as get_openai_models_responses, -) -from open_webui.apps.retrieval.main import app as retrieval_app -from open_webui.apps.retrieval.utils import get_sources_from_files - -from open_webui.apps.socket.main import ( +from open_webui.socket.main import ( app as socket_app, periodic_usage_pool_cleanup, - get_event_call, - get_event_emitter, ) -from open_webui.apps.webui.internal.db import Session -from open_webui.apps.webui.main import ( - app as webui_app, - generate_function_chat_completion, - get_all_models as get_open_webui_models, +from open_webui.routers import ( + audio, + images, + ollama, + openai, + retrieval, + pipelines, + tasks, + auths, + chats, + folders, + configs, + groups, + files, + functions, + memories, + models, + knowledge, + prompts, + evaluations, + tools, + users, + utils, ) -from open_webui.apps.webui.models.functions import Functions -from open_webui.apps.webui.models.models import Models -from open_webui.apps.webui.models.users import UserModel, Users -from open_webui.apps.webui.utils import load_function_module_by_id + +from open_webui.routers.retrieval import ( + get_embedding_function, + get_ef, + get_rf, +) + +from open_webui.internal.db import Session + +from open_webui.models.functions import Functions +from open_webui.models.models import Models +from open_webui.models.users import UserModel, Users + from open_webui.config import ( - CACHE_DIR, - CORS_ALLOW_ORIGIN, - DEFAULT_LOCALE, - ENABLE_ADMIN_CHAT_ACCESS, - ENABLE_ADMIN_EXPORT, + # Ollama ENABLE_OLLAMA_API, + OLLAMA_BASE_URLS, + OLLAMA_API_CONFIGS, + # OpenAI 
ENABLE_OPENAI_API, - ENABLE_TAGS_GENERATION, - ENV, - FRONTEND_BUILD_DIR, - OAUTH_PROVIDERS, - STATIC_DIR, - TASK_MODEL, - TASK_MODEL_EXTERNAL, - ENABLE_SEARCH_QUERY_GENERATION, - ENABLE_RETRIEVAL_QUERY_GENERATION, - QUERY_GENERATION_PROMPT_TEMPLATE, - DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, - TITLE_GENERATION_PROMPT_TEMPLATE, - TAGS_GENERATION_PROMPT_TEMPLATE, - ENABLE_AUTOCOMPLETE_GENERATION, - AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, - AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, - DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - WEBHOOK_URL, + OPENAI_API_BASE_URLS, + OPENAI_API_KEYS, + OPENAI_API_CONFIGS, + # Image + AUTOMATIC1111_API_AUTH, + AUTOMATIC1111_BASE_URL, + AUTOMATIC1111_CFG_SCALE, + AUTOMATIC1111_SAMPLER, + AUTOMATIC1111_SCHEDULER, + COMFYUI_BASE_URL, + COMFYUI_WORKFLOW, + COMFYUI_WORKFLOW_NODES, + ENABLE_IMAGE_GENERATION, + IMAGE_GENERATION_ENGINE, + IMAGE_GENERATION_MODEL, + IMAGE_SIZE, + IMAGE_STEPS, + IMAGES_OPENAI_API_BASE_URL, + IMAGES_OPENAI_API_KEY, + # Audio + AUDIO_STT_ENGINE, + AUDIO_STT_MODEL, + AUDIO_STT_OPENAI_API_BASE_URL, + AUDIO_STT_OPENAI_API_KEY, + AUDIO_TTS_API_KEY, + AUDIO_TTS_ENGINE, + AUDIO_TTS_MODEL, + AUDIO_TTS_OPENAI_API_BASE_URL, + AUDIO_TTS_OPENAI_API_KEY, + AUDIO_TTS_SPLIT_ON, + AUDIO_TTS_VOICE, + AUDIO_TTS_AZURE_SPEECH_REGION, + AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, + WHISPER_MODEL, + WHISPER_MODEL_AUTO_UPDATE, + WHISPER_MODEL_DIR, + # Retrieval + RAG_TEMPLATE, + DEFAULT_RAG_TEMPLATE, + RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, + RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + RAG_RERANKING_MODEL, + RAG_RERANKING_MODEL_AUTO_UPDATE, + RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + RAG_EMBEDDING_ENGINE, + RAG_EMBEDDING_BATCH_SIZE, + RAG_RELEVANCE_THRESHOLD, + RAG_FILE_MAX_COUNT, + RAG_FILE_MAX_SIZE, + RAG_OPENAI_API_BASE_URL, + RAG_OPENAI_API_KEY, + RAG_OLLAMA_BASE_URL, + RAG_OLLAMA_API_KEY, + CHUNK_OVERLAP, + CHUNK_SIZE, + CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL, + RAG_TOP_K, + RAG_TEXT_SPLITTER, + TIKTOKEN_ENCODING_NAME, + PDF_EXTRACT_IMAGES, + YOUTUBE_LOADER_LANGUAGE, + YOUTUBE_LOADER_PROXY_URL, + # Retrieval (Web Search) + RAG_WEB_SEARCH_ENGINE, + RAG_WEB_SEARCH_RESULT_COUNT, + RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + JINA_API_KEY, + SEARCHAPI_API_KEY, + SEARCHAPI_ENGINE, + SEARXNG_QUERY_URL, + SERPER_API_KEY, + SERPLY_API_KEY, + SERPSTACK_API_KEY, + SERPSTACK_HTTPS, + TAVILY_API_KEY, + BING_SEARCH_V7_ENDPOINT, + BING_SEARCH_V7_SUBSCRIPTION_KEY, + BRAVE_SEARCH_API_KEY, + KAGI_SEARCH_API_KEY, + MOJEEK_SEARCH_API_KEY, + GOOGLE_PSE_API_KEY, + GOOGLE_PSE_ENGINE_ID, + ENABLE_RAG_HYBRID_SEARCH, + ENABLE_RAG_LOCAL_WEB_FETCH, + ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + ENABLE_RAG_WEB_SEARCH, + UPLOAD_DIR, + # WebUI WEBUI_AUTH, WEBUI_NAME, + WEBUI_BANNERS, + WEBHOOK_URL, + ADMIN_EMAIL, + SHOW_ADMIN_DETAILS, + JWT_EXPIRES_IN, + ENABLE_SIGNUP, + ENABLE_LOGIN_FORM, + ENABLE_API_KEY, + ENABLE_COMMUNITY_SHARING, + ENABLE_MESSAGE_RATING, + ENABLE_EVALUATION_ARENA_MODELS, + USER_PERMISSIONS, + DEFAULT_USER_ROLE, + DEFAULT_PROMPT_SUGGESTIONS, + DEFAULT_MODELS, + DEFAULT_ARENA_MODEL, + MODEL_ORDER_LIST, + EVALUATION_ARENA_MODELS, + # WebUI (OAuth) + ENABLE_OAUTH_ROLE_MANAGEMENT, + OAUTH_ROLES_CLAIM, + OAUTH_EMAIL_CLAIM, + OAUTH_PICTURE_CLAIM, + OAUTH_USERNAME_CLAIM, + OAUTH_ALLOWED_ROLES, + OAUTH_ADMIN_ROLES, + # WebUI (LDAP) + ENABLE_LDAP, + LDAP_SERVER_LABEL, + LDAP_SERVER_HOST, + LDAP_SERVER_PORT, + LDAP_ATTRIBUTE_FOR_USERNAME, + LDAP_SEARCH_FILTERS, + LDAP_SEARCH_BASE, + 
LDAP_APP_DN, + LDAP_APP_PASSWORD, + LDAP_USE_TLS, + LDAP_CA_CERT_FILE, + LDAP_CIPHERS, + # Misc + ENV, + CACHE_DIR, + STATIC_DIR, + FRONTEND_BUILD_DIR, + CORS_ALLOW_ORIGIN, + DEFAULT_LOCALE, + OAUTH_PROVIDERS, + # Admin + ENABLE_ADMIN_CHAT_ACCESS, + ENABLE_ADMIN_EXPORT, + # Tasks + TASK_MODEL, + TASK_MODEL_EXTERNAL, + ENABLE_TAGS_GENERATION, + ENABLE_SEARCH_QUERY_GENERATION, + ENABLE_RETRIEVAL_QUERY_GENERATION, + ENABLE_AUTOCOMPLETE_GENERATION, + TITLE_GENERATION_PROMPT_TEMPLATE, + TAGS_GENERATION_PROMPT_TEMPLATE, + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + QUERY_GENERATION_PROMPT_TEMPLATE, + AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, AppConfig, reset_config, ) -from open_webui.constants import TASKS from open_webui.env import ( CHANGELOG, GLOBAL_LOG_LEVEL, SAFE_MODE, SRC_LOG_LEVELS, VERSION, + WEBUI_URL, WEBUI_BUILD_HASH, WEBUI_SECRET_KEY, WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE, - WEBUI_URL, + WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + WEBUI_AUTH_TRUSTED_NAME_HEADER, BYPASS_MODEL_ACCESS_CONTROL, RESET_CONFIG_ON_START, OFFLINE_MODE, ) -from open_webui.utils.misc import ( - add_or_update_system_message, - get_last_user_message, - prepend_to_first_user_message_content, + + +from open_webui.utils.models import ( + get_all_models, + get_all_base_models, + check_model_access, ) -from open_webui.utils.oauth import oauth_manager -from open_webui.utils.payload import convert_payload_openai_to_ollama -from open_webui.utils.response import ( - convert_response_ollama_to_openai, - convert_streaming_response_ollama_to_openai, +from open_webui.utils.chat import ( + generate_chat_completion as chat_completion_handler, + chat_completed as chat_completed_handler, + chat_action as chat_action_handler, ) -from open_webui.utils.security_headers import SecurityHeadersMiddleware -from open_webui.utils.task import ( - rag_template, - title_generation_template, - query_generation_template, - autocomplete_generation_template, - tags_generation_template, - emoji_generation_template, - moa_response_generation_template, - tools_function_calling_generation_template, -) -from open_webui.utils.tools import get_tools -from open_webui.utils.utils import ( +from open_webui.utils.middleware import process_chat_payload, process_chat_response +from open_webui.utils.access_control import has_access + +from open_webui.utils.auth import ( decode_token, get_admin_user, - get_current_user, - get_http_authorization_cred, get_verified_user, ) -from open_webui.utils.access_control import has_access +from open_webui.utils.oauth import oauth_manager +from open_webui.utils.security_headers import SecurityHeadersMiddleware + if SAFE_MODE: print("SAFE MODE ENABLED") @@ -203,757 +341,298 @@ app = FastAPI( app.state.config = AppConfig() -app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API -app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API +######################################## +# +# OLLAMA +# +######################################## + + +app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API +app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS +app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS + +app.state.OLLAMA_MODELS = {} + +######################################## +# +# OPENAI +# +######################################## + +app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API +app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS +app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS +app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS + 
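# A minimal sketch of the pattern this refactor establishes: with the old
# sub-apps folded into one FastAPI app, routers no longer import each other's
# app objects; every handler reads shared settings off request.app.state, as
# the rewritten endpoints further below do. SimpleNamespace stands in for the
# real AppConfig, and the field name is illustrative.
from types import SimpleNamespace

from fastapi import FastAPI, Request

app = FastAPI()
app.state.config = SimpleNamespace(ENABLE_SIGNUP=True)  # stand-in for AppConfig


@app.get("/signup-enabled")
async def signup_enabled(request: Request):
    # Any router included on this app sees the same state object.
    return {"enabled": request.app.state.config.ENABLE_SIGNUP}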
+app.state.OPENAI_MODELS = {} + +######################################## +# +# WEBUI +# +######################################## + +app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP +app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM +app.state.config.ENABLE_API_KEY = ENABLE_API_KEY + +app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN + +app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS +app.state.config.ADMIN_EMAIL = ADMIN_EMAIL + + +app.state.config.DEFAULT_MODELS = DEFAULT_MODELS +app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS +app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE + +app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.WEBHOOK_URL = WEBHOOK_URL +app.state.config.BANNERS = WEBUI_BANNERS +app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST + +app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING +app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING + +app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS +app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS + +app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM +app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM +app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM + +app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT +app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM +app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES +app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES + +app.state.config.ENABLE_LDAP = ENABLE_LDAP +app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL +app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST +app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT +app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME +app.state.config.LDAP_APP_DN = LDAP_APP_DN +app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD +app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE +app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS +app.state.config.LDAP_USE_TLS = LDAP_USE_TLS +app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE +app.state.config.LDAP_CIPHERS = LDAP_CIPHERS + + +app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER +app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER + +app.state.TOOLS = {} +app.state.FUNCTIONS = {} + + +######################################## +# +# RETRIEVAL +# +######################################## + + +app.state.config.TOP_K = RAG_TOP_K +app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD +app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE +app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT + +app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH +app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( + ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION +) + +app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE +app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL + +app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER +app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME + +app.state.config.CHUNK_SIZE = CHUNK_SIZE +app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP + +app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE +app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL +app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE +app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL +app.state.config.RAG_TEMPLATE = RAG_TEMPLATE + 
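# A toy illustration of what the CHUNK_SIZE and CHUNK_OVERLAP settings above
# mean for retrieval: documents are cut into windows of chunk_size characters
# that each share chunk_overlap characters with their neighbor. The real app
# delegates to its configured text splitter; this character-based version is
# only a sketch of the two knobs.
def split_text(text: str, chunk_size: int = 1000, chunk_overlap: int = 100) -> list[str]:
    if chunk_overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk size")
    step = chunk_size - chunk_overlap
    # Each chunk starts `step` characters after the previous one, so
    # consecutive chunks share `chunk_overlap` characters of context.
    return [text[i : i + chunk_size] for i in range(0, max(len(text) - chunk_overlap, 1), step)]


# split_text("abcdefghij", chunk_size=4, chunk_overlap=2)
# -> ['abcd', 'cdef', 'efgh', 'ghij']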
+app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL +app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY + +app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL +app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY + +app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES + +app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE +app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL + + +app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH +app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE +app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST + +app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL +app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY +app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID +app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY +app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY +app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY +app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY +app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS +app.state.config.SERPER_API_KEY = SERPER_API_KEY +app.state.config.SERPLY_API_KEY = SERPLY_API_KEY +app.state.config.TAVILY_API_KEY = TAVILY_API_KEY +app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY +app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE +app.state.config.JINA_API_KEY = JINA_API_KEY +app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT +app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY + +app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT +app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS + +app.state.EMBEDDING_FUNCTION = None +app.state.ef = None +app.state.rf = None + +app.state.YOUTUBE_LOADER_TRANSLATION = None + + +app.state.EMBEDDING_FUNCTION = get_embedding_function( + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, + app.state.ef, + ( + app.state.config.RAG_OPENAI_API_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.RAG_OLLAMA_BASE_URL + ), + ( + app.state.config.RAG_OPENAI_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.RAG_OLLAMA_API_KEY + ), + app.state.config.RAG_EMBEDDING_BATCH_SIZE, +) + +try: + app.state.ef = get_ef( + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, + ) + + app.state.rf = get_rf( + app.state.config.RAG_RERANKING_MODEL, + RAG_RERANKING_MODEL_AUTO_UPDATE, + ) +except Exception as e: + log.error(f"Error updating models: {e}") + pass + + +######################################## +# +# IMAGES +# +######################################## + +app.state.config.IMAGE_GENERATION_ENGINE = IMAGE_GENERATION_ENGINE +app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION + +app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL +app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY + +app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL + +app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL +app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH +app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE +app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER +app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER +app.state.config.COMFYUI_BASE_URL = 
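# A small sketch of the endpoint selection threaded through the
# get_embedding_function call above: the embedding engine decides whether the
# OpenAI-compatible or the Ollama base URL (and matching key) is used. The
# helper name and sample values below are illustrative, not the project's API.
from types import SimpleNamespace


def pick_embedding_endpoint(engine: str, config) -> tuple[str, str]:
    # "openai" selects the RAG OpenAI endpoint; any other engine falls back
    # to the RAG Ollama endpoint, mirroring the conditionals above.
    if engine == "openai":
        return config.RAG_OPENAI_API_BASE_URL, config.RAG_OPENAI_API_KEY
    return config.RAG_OLLAMA_BASE_URL, config.RAG_OLLAMA_API_KEY


cfg = SimpleNamespace(
    RAG_OPENAI_API_BASE_URL="https://api.openai.com/v1",
    RAG_OPENAI_API_KEY="sk-placeholder",
    RAG_OLLAMA_BASE_URL="http://localhost:11434",
    RAG_OLLAMA_API_KEY="",
)
print(pick_embedding_endpoint("ollama", cfg))  # ('http://localhost:11434', '')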
COMFYUI_BASE_URL +app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW +app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES + +app.state.config.IMAGE_SIZE = IMAGE_SIZE +app.state.config.IMAGE_STEPS = IMAGE_STEPS + + +######################################## +# +# AUDIO +# +######################################## + +app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL +app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY +app.state.config.STT_ENGINE = AUDIO_STT_ENGINE +app.state.config.STT_MODEL = AUDIO_STT_MODEL + +app.state.config.WHISPER_MODEL = WHISPER_MODEL + +app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL +app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY +app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE +app.state.config.TTS_MODEL = AUDIO_TTS_MODEL +app.state.config.TTS_VOICE = AUDIO_TTS_VOICE +app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY +app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON + + +app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION +app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT + + +app.state.faster_whisper_model = None +app.state.speech_synthesiser = None +app.state.speech_speaker_embeddings_dataset = None + + +######################################## +# +# TASKS +# +######################################## + app.state.config.TASK_MODEL = TASK_MODEL app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL -app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE +app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION +app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION +app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION + + +app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE +app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE +app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE +) +app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE +app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = ( + AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE +) app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH ) -app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION -app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE - -app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION -app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION -app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE - -app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = ( - AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE -) - -app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE -) - -################################## +######################################## # -# ChatCompletion Middleware +# WEBUI # -################################## +######################################## - -def get_filter_function_ids(model): - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel - return (function.valves if 
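# A minimal reproduction of the priority ordering used for filter functions
# in the (removed) handler above: each filter's valves may carry a "priority"
# value, missing values default to 0, and filters run in ascending order.
# Plain dicts stand in for the Functions table rows.
filters = [
    {"id": "profanity", "valves": {"priority": 10}},
    {"id": "rate_limit", "valves": {}},  # no priority -> defaults to 0
    {"id": "rewrite", "valves": {"priority": -5}},
]


def get_priority(f: dict) -> int:
    # Missing valves or a missing key both default to priority 0.
    return (f.get("valves") or {}).get("priority", 0)


filters.sort(key=get_priority)
print([f["id"] for f in filters])  # ['rewrite', 'rate_limit', 'profanity']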
function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - filter_ids.sort(key=get_priority) - return filter_ids - - -async def chat_completion_filter_functions_handler(body, model, extra_params): - skip_files = None - - filter_ids = get_filter_function_ids(model) - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module - - # Check if the function has a file_handler variable - if hasattr(function_module, "file_handler"): - skip_files = function_module.file_handler - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if not hasattr(function_module, "inlet"): - continue - - try: - inlet = function_module.inlet - - # Get the signature of the function - sig = inspect.signature(inlet) - params = {"body": body} | { - k: v - for k, v in { - **extra_params, - "__model__": model, - "__id__": filter_id, - }.items() - if k in sig.parameters - } - - if "__user__" in params and hasattr(function_module, "UserValves"): - try: - params["__user__"]["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, params["__user__"]["id"] - ) - ) - except Exception as e: - print(e) - - if inspect.iscoroutinefunction(inlet): - body = await inlet(**params) - else: - body = inlet(**params) - - except Exception as e: - print(f"Error: {e}") - raise e - - if skip_files and "files" in body.get("metadata", {}): - del body["metadata"]["files"] - - return body, {} - - -def get_tools_function_calling_payload(messages, task_model_id, content): - user_message = get_last_user_message(messages) - history = "\n".join( - f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" - for message in messages[::-1][:4] - ) - - prompt = f"History:\n{history}\nQuery: {user_message}" - - return { - "model": task_model_id, - "messages": [ - {"role": "system", "content": content}, - {"role": "user", "content": f"Query: {prompt}"}, - ], - "stream": False, - "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, - } - - -async def get_content_from_response(response) -> Optional[str]: - content = None - if hasattr(response, "body_iterator"): - async for chunk in response.body_iterator: - data = json.loads(chunk.decode("utf-8")) - content = data["choices"][0]["message"]["content"] - - # Cleanup any remaining background tasks if necessary - if response.background is not None: - await response.background() - else: - content = response["choices"][0]["message"]["content"] - return content - - -def get_task_model_id( - default_model_id: str, task_model: str, task_model_external: str, models -) -> str: - # Set the task model - task_model_id = default_model_id - # Check if the user has a custom task model 
and use that model - if models[task_model_id]["owned_by"] == "ollama": - if task_model and task_model in models: - task_model_id = task_model - else: - if task_model_external and task_model_external in models: - task_model_id = task_model_external - - return task_model_id - - -async def chat_completion_tools_handler( - body: dict, user: UserModel, models, extra_params: dict -) -> tuple[dict, dict]: - # If tool_ids field is present, call the functions - metadata = body.get("metadata", {}) - - tool_ids = metadata.get("tool_ids", None) - log.debug(f"{tool_ids=}") - if not tool_ids: - return body, {} - - skip_files = False - sources = [] - - task_model_id = get_task_model_id( - body["model"], - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - tools = get_tools( - webui_app, - tool_ids, - user, - { - **extra_params, - "__model__": models[task_model_id], - "__messages__": body["messages"], - "__files__": metadata.get("files", []), - }, - ) - log.info(f"{tools=}") - - specs = [tool["spec"] for tool in tools.values()] - tools_specs = json.dumps(specs) - - if app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "": - template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - else: - template = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text.""" - - tools_function_calling_prompt = tools_function_calling_generation_template( - template, tools_specs - ) - log.info(f"{tools_function_calling_prompt=}") - payload = get_tools_function_calling_payload( - body["messages"], task_model_id, tools_function_calling_prompt - ) - - try: - payload = filter_pipeline(payload, user, models) - except Exception as e: - raise e - - try: - response = await generate_chat_completions(form_data=payload, user=user) - log.debug(f"{response=}") - content = await get_content_from_response(response) - log.debug(f"{content=}") - - if not content: - return body, {} - - try: - content = content[content.find("{") : content.rfind("}") + 1] - if not content: - raise Exception("No JSON object found in the response") - - result = json.loads(content) - - tool_function_name = result.get("name", None) - if tool_function_name not in tools: - return body, {} - - tool_function_params = result.get("parameters", {}) - - try: - required_params = ( - tools[tool_function_name] - .get("spec", {}) - .get("parameters", {}) - .get("required", []) - ) - tool_function = tools[tool_function_name]["callable"] - tool_function_params = { - k: v - for k, v in tool_function_params.items() - if k in required_params - } - tool_output = await tool_function(**tool_function_params) - - except Exception as e: - tool_output = str(e) - - print(tools[tool_function_name]["citation"]) - - if isinstance(tool_output, str): - if tools[tool_function_name]["citation"]: - sources.append( - { - "source": { - "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" - }, - "document": [tool_output], - "metadata": [ - { - "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" - } - ], - } - ) - else: - sources.append( - { - "source": {}, - "document": [tool_output], - "metadata": [ - { - "source": 
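# A standalone sketch of the two defensive steps in the tool-calling handler
# above: the task model is trusted only to put a JSON object *somewhere* in
# its reply, so the code slices from the first "{" to the last "}" before
# parsing, then drops any parameters the tool's spec does not declare. The
# function name is illustrative.
import json


def extract_tool_call(content: str, required_params: list[str]) -> dict | None:
    start, end = content.find("{"), content.rfind("}") + 1
    if start == -1 or end == 0:
        return None  # no JSON object in the response
    try:
        result = json.loads(content[start:end])
    except json.JSONDecodeError:
        return None
    # Keep only parameters the tool actually declares, discarding extras the
    # model may have hallucinated.
    params = {
        k: v for k, v in result.get("parameters", {}).items() if k in required_params
    }
    return {"name": result.get("name"), "parameters": params}


# extract_tool_call(
#     'Sure! {"name": "get_weather", "parameters": {"city": "Oslo", "mood": "happy"}}',
#     ["city"],
# )
# -> {'name': 'get_weather', 'parameters': {'city': 'Oslo'}}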
f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" - } - ], - } - ) - - if tools[tool_function_name]["file_handler"]: - skip_files = True - - except Exception as e: - log.exception(f"Error: {e}") - content = None - except Exception as e: - log.exception(f"Error: {e}") - content = None - - log.debug(f"tool_contexts: {sources}") - - if skip_files and "files" in body.get("metadata", {}): - del body["metadata"]["files"] - - return body, {"sources": sources} - - -async def chat_completion_files_handler( - body: dict, user: UserModel -) -> tuple[dict, dict[str, list]]: - sources = [] - - if files := body.get("metadata", {}).get("files", None): - try: - queries_response = await generate_queries( - { - "model": body["model"], - "messages": body["messages"], - "type": "retrieval", - }, - user, - ) - queries_response = queries_response["choices"][0]["message"]["content"] - - try: - bracket_start = queries_response.find("{") - bracket_end = queries_response.rfind("}") + 1 - - if bracket_start == -1 or bracket_end == -1: - raise Exception("No JSON object found in the response") - - queries_response = queries_response[bracket_start:bracket_end] - queries_response = json.loads(queries_response) - except Exception as e: - queries_response = {"queries": [queries_response]} - - queries = queries_response.get("queries", []) - except Exception as e: - queries = [] - - if len(queries) == 0: - queries = [get_last_user_message(body["messages"])] - - sources = get_sources_from_files( - files=files, - queries=queries, - embedding_function=retrieval_app.state.EMBEDDING_FUNCTION, - k=retrieval_app.state.config.TOP_K, - reranking_function=retrieval_app.state.sentence_transformer_rf, - r=retrieval_app.state.config.RELEVANCE_THRESHOLD, - hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH, - ) - - log.debug(f"rag_contexts:sources: {sources}") - return body, {"sources": sources} - - -def is_chat_completion_request(request): - return request.method == "POST" and any( - endpoint in request.url.path - for endpoint in ["/ollama/api/chat", "/chat/completions"] - ) - - -async def get_body_and_model_and_user(request, models): - # Read the original request body - body = await request.body() - body_str = body.decode("utf-8") - body = json.loads(body_str) if body_str else {} - - model_id = body["model"] - if model_id not in models: - raise Exception("Model not found") - model = models[model_id] - - user = get_current_user( - request, - get_http_authorization_cred(request.headers.get("Authorization")), - ) - - return body, model, user - - -class ChatCompletionMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - if not is_chat_completion_request(request): - return await call_next(request) - log.debug(f"request.url.path: {request.url.path}") - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - try: - body, model, user = await get_body_and_model_and_user(request, models) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - model_info = Models.get_model_by_id(model["id"]) - if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: - if model.get("arena"): - if not has_access( - user.id, - type="read", - access_control=model.get("info", {}) - .get("meta", {}) - .get("access_control", {}), - ): - raise HTTPException( - status_code=403, - detail="Model not found", - ) - else: - if not model_info: - return JSONResponse( - 
status_code=status.HTTP_404_NOT_FOUND, - content={"detail": "Model not found"}, - ) - elif not ( - user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control - ) - ): - return JSONResponse( - status_code=status.HTTP_403_FORBIDDEN, - content={"detail": "User does not have access to the model"}, - ) - - metadata = { - "chat_id": body.pop("chat_id", None), - "message_id": body.pop("id", None), - "session_id": body.pop("session_id", None), - "tool_ids": body.get("tool_ids", None), - "files": body.get("files", None), - } - body["metadata"] = metadata - - extra_params = { - "__event_emitter__": get_event_emitter(metadata), - "__event_call__": get_event_call(metadata), - "__user__": { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - "__metadata__": metadata, - } - - # Initialize data_items to store additional data to be sent to the client - # Initialize contexts and citation - data_items = [] - sources = [] - - try: - body, flags = await chat_completion_filter_functions_handler( - body, model, extra_params - ) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - tool_ids = body.pop("tool_ids", None) - files = body.pop("files", None) - - metadata = { - **metadata, - "tool_ids": tool_ids, - "files": files, - } - body["metadata"] = metadata - - try: - body, flags = await chat_completion_tools_handler( - body, user, models, extra_params - ) - sources.extend(flags.get("sources", [])) - except Exception as e: - log.exception(e) - - try: - body, flags = await chat_completion_files_handler(body, user) - sources.extend(flags.get("sources", [])) - except Exception as e: - log.exception(e) - - # If context is not empty, insert it into the messages - if len(sources) > 0: - context_string = "" - for source_idx, source in enumerate(sources): - source_id = source.get("source", {}).get("name", "") - - if "document" in source: - for doc_idx, doc_context in enumerate(source["document"]): - metadata = source.get("metadata") - doc_source_id = None - - if metadata: - doc_source_id = metadata[doc_idx].get("source", source_id) - - if source_id: - context_string += f"{doc_source_id if doc_source_id is not None else source_id}{doc_context}\n" - else: - # If there is no source_id, then do not include the source_id tag - context_string += f"{doc_context}\n" - - context_string = context_string.strip() - prompt = get_last_user_message(body["messages"]) - - if prompt is None: - raise Exception("No user message found") - if ( - retrieval_app.state.config.RELEVANCE_THRESHOLD == 0 - and context_string.strip() == "" - ): - log.debug( - f"With a 0 relevancy threshold for RAG, the context cannot be empty" - ) - - # Workaround for Ollama 2.0+ system prompt issue - # TODO: replace with add_or_update_system_message - if model["owned_by"] == "ollama": - body["messages"] = prepend_to_first_user_message_content( - rag_template( - retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt - ), - body["messages"], - ) - else: - body["messages"] = add_or_update_system_message( - rag_template( - retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt - ), - body["messages"], - ) - - # If there are citations, add them to the data_items - sources = [ - source for source in sources if source.get("source", {}).get("name", "") - ] - if len(sources) > 0: - data_items.append({"sources": sources}) - - modified_body_bytes = json.dumps(body).encode("utf-8") - # Replace the 
request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], - ] - - response = await call_next(request) - if not isinstance(response, StreamingResponse): - return response - - content_type = response.headers["Content-Type"] - is_openai = "text/event-stream" in content_type - is_ollama = "application/x-ndjson" in content_type - if not is_openai and not is_ollama: - return response - - def wrap_item(item): - return f"data: {item}\n\n" if is_openai else f"{item}\n" - - async def stream_wrapper(original_generator, data_items): - for item in data_items: - yield wrap_item(json.dumps(item)) - - async for data in original_generator: - yield data - - return StreamingResponse( - stream_wrapper(response.body_iterator, data_items), - headers=dict(response.headers), - ) - - async def _receive(self, body: bytes): - return {"type": "http.request", "body": body, "more_body": False} - - -app.add_middleware(ChatCompletionMiddleware) - - -################################## -# -# Pipeline Middleware -# -################################## - - -def get_sorted_filters(model_id, models): - filters = [ - model - for model in models.values() - if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) - ] - sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) - return sorted_filters - - -def filter_pipeline(payload, user, models): - user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} - model_id = payload["model"] - - sorted_filters = get_sorted_filters(model_id, models) - model = models[model_id] - - if "pipeline" in model: - sorted_filters.append(model) - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key == "": - continue - - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/inlet", - headers=headers, - json={ - "user": user, - "body": payload, - }, - ) - - r.raise_for_status() - payload = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - res = r.json() - if "detail" in res: - raise Exception(r.status_code, res["detail"]) - - return payload - - -class PipelineMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - if not is_chat_completion_request(request): - return await call_next(request) - - log.debug(f"request.url.path: {request.url.path}") - - # Read the original request body - body = await request.body() - # Decode body to string - body_str = body.decode("utf-8") - # Parse string to JSON - data = json.loads(body_str) if body_str else {} - - try: - user = get_current_user( - request, - get_http_authorization_cred(request.headers["Authorization"]), - ) - except KeyError as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={"detail": "Not 
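# A trimmed-down version of the round trip in filter_pipeline above: each
# pipeline filter is an external HTTP service, the inlet hook POSTs the user
# and the in-flight payload to {base_url}/{filter_id}/filter/inlet, and the
# (possibly modified) payload comes back as JSON. The explicit timeout is an
# addition for safety here; the call in the hunk above does not set one.
import requests


def run_inlet_filter(
    base_url: str, api_key: str, filter_id: str, user: dict, payload: dict
) -> dict:
    r = requests.post(
        f"{base_url}/{filter_id}/filter/inlet",
        headers={"Authorization": f"Bearer {api_key}"},
        json={"user": user, "body": payload},
        timeout=30,
    )
    r.raise_for_status()
    # The filter returns the payload to forward to the next filter/model.
    return r.json()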
authenticated"}, - ) - except HTTPException as e: - return JSONResponse( - status_code=e.status_code, - content={"detail": e.detail}, - ) - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - try: - data = filter_pipeline(data, user, models) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - modified_body_bytes = json.dumps(data).encode("utf-8") - # Replace the request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], - ] - - response = await call_next(request) - return response - - async def _receive(self, body: bytes): - return {"type": "http.request", "body": body, "more_body": False} - - -app.add_middleware(PipelineMiddleware) - - -from urllib.parse import urlencode, parse_qs, urlparse +app.state.MODELS = {} class RedirectMiddleware(BaseHTTPMiddleware): @@ -977,16 +656,6 @@ class RedirectMiddleware(BaseHTTPMiddleware): # Add the middleware to the app app.add_middleware(RedirectMiddleware) - - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - app.add_middleware(SecurityHeadersMiddleware) @@ -1001,21 +670,13 @@ async def commit_session_after_request(request: Request, call_next): @app.middleware("http") async def check_url(request: Request, call_next): start_time = int(time.time()) - request.state.enable_api_key = webui_app.state.config.ENABLE_API_KEY + request.state.enable_api_key = app.state.config.ENABLE_API_KEY response = await call_next(request) process_time = int(time.time()) - start_time response.headers["X-Process-Time"] = str(process_time) return response -@app.middleware("http") -async def update_embedding_function(request: Request, call_next): - response = await call_next(request) - if "/embedding/update" in request.url.path: - webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION - return response - - @app.middleware("http") async def inspect_websocket(request: Request, call_next): if ( @@ -1034,198 +695,61 @@ async def inspect_websocket(request: Request, call_next): return await call_next(request) +app.add_middleware( + CORSMiddleware, + allow_origins=CORS_ALLOW_ORIGIN, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + app.mount("/ws", socket_app) -app.mount("/ollama", ollama_app) -app.mount("/openai", openai_app) - -app.mount("/images/api/v1", images_app) -app.mount("/audio/api/v1", audio_app) -app.mount("/retrieval/api/v1", retrieval_app) - -app.mount("/api/v1", webui_app) - -webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION -async def get_all_base_models(): - open_webui_models = [] - openai_models = [] - ollama_models = [] - - if app.state.config.ENABLE_OPENAI_API: - openai_models = await get_openai_models() - openai_models = openai_models["data"] - - if app.state.config.ENABLE_OLLAMA_API: - ollama_models = await get_ollama_models() - ollama_models = [ - { - "id": model["model"], - "name": model["name"], - "object": "model", - "created": int(time.time()), - "owned_by": "ollama", - "ollama": 
model, - } - for model in ollama_models["models"] - ] - - open_webui_models = await get_open_webui_models() - - models = open_webui_models + openai_models + ollama_models - return models +app.include_router(ollama.router, prefix="/ollama", tags=["ollama"]) +app.include_router(openai.router, prefix="/openai", tags=["openai"]) -@cached(ttl=3) -async def get_all_models(): - models = await get_all_base_models() +app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"]) +app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"]) +app.include_router(images.router, prefix="/api/v1/images", tags=["images"]) +app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"]) +app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"]) - # If there are no models, return an empty list - if len([model for model in models if not model.get("arena", False)]) == 0: - return [] +app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"]) - global_action_ids = [ - function.id for function in Functions.get_global_action_functions() - ] - enabled_action_ids = [ - function.id - for function in Functions.get_functions_by_type("action", active_only=True) - ] +app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"]) +app.include_router(users.router, prefix="/api/v1/users", tags=["users"]) - custom_models = Models.get_all_models() - for custom_model in custom_models: - if custom_model.base_model_id is None: - for model in models: - if ( - custom_model.id == model["id"] - or custom_model.id == model["id"].split(":")[0] - ): - if custom_model.is_active: - model["name"] = custom_model.name - model["info"] = custom_model.model_dump() +app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"]) - action_ids = [] - if "info" in model and "meta" in model["info"]: - action_ids.extend( - model["info"]["meta"].get("actionIds", []) - ) +app.include_router(models.router, prefix="/api/v1/models", tags=["models"]) +app.include_router(knowledge.router, prefix="/api/v1/knowledge", tags=["knowledge"]) +app.include_router(prompts.router, prefix="/api/v1/prompts", tags=["prompts"]) +app.include_router(tools.router, prefix="/api/v1/tools", tags=["tools"]) - model["action_ids"] = action_ids - else: - models.remove(model) +app.include_router(memories.router, prefix="/api/v1/memories", tags=["memories"]) +app.include_router(folders.router, prefix="/api/v1/folders", tags=["folders"]) +app.include_router(groups.router, prefix="/api/v1/groups", tags=["groups"]) +app.include_router(files.router, prefix="/api/v1/files", tags=["files"]) +app.include_router(functions.router, prefix="/api/v1/functions", tags=["functions"]) +app.include_router( + evaluations.router, prefix="/api/v1/evaluations", tags=["evaluations"] +) +app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"]) - elif custom_model.is_active and ( - custom_model.id not in [model["id"] for model in models] - ): - owned_by = "openai" - pipe = None - action_ids = [] - for model in models: - if ( - custom_model.base_model_id == model["id"] - or custom_model.base_model_id == model["id"].split(":")[0] - ): - owned_by = model["owned_by"] - if "pipe" in model: - pipe = model["pipe"] - break - - if custom_model.meta: - meta = custom_model.meta.model_dump() - if "actionIds" in meta: - action_ids.extend(meta["actionIds"]) - - models.append( - { - "id": f"{custom_model.id}", - "name": custom_model.name, - "object": "model", - "created": 
custom_model.created_at, - "owned_by": owned_by, - "info": custom_model.model_dump(), - "preset": True, - **({"pipe": pipe} if pipe is not None else {}), - "action_ids": action_ids, - } - ) - - # Process action_ids to get the actions - def get_action_items_from_module(function, module): - actions = [] - if hasattr(module, "actions"): - actions = module.actions - return [ - { - "id": f"{function.id}.{action['id']}", - "name": action.get("name", f"{function.name} ({action['id']})"), - "description": function.meta.description, - "icon_url": action.get( - "icon_url", function.meta.manifest.get("icon_url", None) - ), - } - for action in actions - ] - else: - return [ - { - "id": function.id, - "name": function.name, - "description": function.meta.description, - "icon_url": function.meta.manifest.get("icon_url", None), - } - ] - - def get_function_module_by_id(function_id): - if function_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[function_id] - else: - function_module, _, _ = load_function_module_by_id(function_id) - webui_app.state.FUNCTIONS[function_id] = function_module - - for model in models: - action_ids = [ - action_id - for action_id in list(set(model.pop("action_ids", []) + global_action_ids)) - if action_id in enabled_action_ids - ] - - model["actions"] = [] - for action_id in action_ids: - action_function = Functions.get_function_by_id(action_id) - if action_function is None: - raise Exception(f"Action not found: {action_id}") - - function_module = get_function_module_by_id(action_id) - model["actions"].extend( - get_action_items_from_module(action_function, function_module) - ) - log.debug(f"get_all_models() returned {len(models)} models") - - return models +################################## +# +# Chat Endpoints +# +################################## @app.get("/api/models") -async def get_models(user=Depends(get_verified_user)): - models = await get_all_models() - - # Filter out filter pipelines - models = [ - model - for model in models - if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" - ] - - model_order_list = webui_app.state.config.MODEL_ORDER_LIST - if model_order_list: - model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)} - # Sort models by order list priority, with fallback for those not in the list - models.sort( - key=lambda x: (model_order_dict.get(x["id"], float("inf")), x["name"]) - ) - - # Filter out models that the user does not have access to - if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: +async def get_models(request: Request, user=Depends(get_verified_user)): + def get_filtered_models(models, user): filtered_models = [] for model in models: if model.get("arena"): @@ -1245,1319 +769,112 @@ async def get_models(user=Depends(get_verified_user)): user.id, type="read", access_control=model_info.access_control ): filtered_models.append(model) - models = filtered_models + + return filtered_models + + models = await get_all_models(request) + + # Filter out filter pipelines + models = [ + model + for model in models + if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" + ] + + model_order_list = request.app.state.config.MODEL_ORDER_LIST + if model_order_list: + model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)} + # Sort models by order list priority, with fallback for those not in the list + models.sort( + key=lambda x: (model_order_dict.get(x["id"], float("inf")), x["name"]) + ) + + # Filter out models that the 
user does not have access to + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: + models = get_filtered_models(models, user) log.debug( f"/api/models returned filtered models accessible to the user: {json.dumps([model['id'] for model in models])}" ) - return {"data": models} @app.get("/api/models/base") -async def get_base_models(user=Depends(get_admin_user)): - models = await get_all_base_models() - - # Filter out arena models - models = [model for model in models if not model.get("arena", False)] +async def get_base_models(request: Request, user=Depends(get_admin_user)): + models = await get_all_base_models(request) return {"data": models} @app.post("/api/chat/completions") -async def generate_chat_completions( - form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False +async def chat_completion( + request: Request, + form_data: dict, + user=Depends(get_verified_user), + bypass_filter: bool = False, ): - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} + if not request.app.state.MODELS: + await get_all_models(request) - model_id = form_data["model"] - if model_id not in models: + try: + model_id = form_data.get("model", None) + if model_id not in request.app.state.MODELS: + raise Exception("Model not found") + model = request.app.state.MODELS[model_id] + + # Check if user has access to the model + if not bypass_filter and user.role == "user": + try: + check_model_access(user, model) + except Exception as e: + raise e + + form_data, events = await process_chat_payload(request, form_data, user, model) + except Exception as e: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), ) - model = models[model_id] - - # Check if user has access to the model - if not bypass_filter and user.role == "user": - if model.get("arena"): - if not has_access( - user.id, - type="read", - access_control=model.get("info", {}) - .get("meta", {}) - .get("access_control", {}), - ): - raise HTTPException( - status_code=403, - detail="Model not found", - ) - else: - model_info = Models.get_model_by_id(model_id) - if not model_info: - raise HTTPException( - status_code=404, - detail="Model not found", - ) - elif not ( - user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control - ) - ): - raise HTTPException( - status_code=403, - detail="Model not found", - ) - - if model["owned_by"] == "arena": - model_ids = model.get("info", {}).get("meta", {}).get("model_ids") - filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") - if model_ids and filter_mode == "exclude": - model_ids = [ - model["id"] - for model in await get_all_models() - if model.get("owned_by") != "arena" and model["id"] not in model_ids - ] - - selected_model_id = None - if isinstance(model_ids, list) and model_ids: - selected_model_id = random.choice(model_ids) - else: - model_ids = [ - model["id"] - for model in await get_all_models() - if model.get("owned_by") != "arena" - ] - selected_model_id = random.choice(model_ids) - - form_data["model"] = selected_model_id - - if form_data.get("stream") == True: - - async def stream_wrapper(stream): - yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" - async for chunk in stream: - yield chunk - - response = await generate_chat_completions( - form_data, user, bypass_filter=True - ) - return StreamingResponse( - stream_wrapper(response.body_iterator), 
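# A self-contained demonstration of the MODEL_ORDER_LIST sort key used in the
# /api/models handler above: models named in the order list sort first, in
# list order; everything else maps to float("inf") and is tie-broken
# alphabetically by name. Sample data is illustrative.
models = [
    {"id": "b", "name": "Beta"},
    {"id": "z", "name": "Zeta"},
    {"id": "a", "name": "Alpha"},
]
model_order_list = ["z", "b"]

model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)}
models.sort(key=lambda x: (model_order_dict.get(x["id"], float("inf")), x["name"]))

print([m["id"] for m in models])  # ['z', 'b', 'a']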
media_type="text/event-stream" - ) - else: - return { - **( - await generate_chat_completions(form_data, user, bypass_filter=True) - ), - "selected_model_id": selected_model_id, - } - - if model.get("pipe"): - # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter - return await generate_function_chat_completion( - form_data, user=user, models=models + try: + response = await chat_completion_handler( + request, form_data, user, bypass_filter ) - if model["owned_by"] == "ollama": - # Using /ollama/api/chat endpoint - form_data = convert_payload_openai_to_ollama(form_data) - form_data = GenerateChatCompletionForm(**form_data) - response = await generate_ollama_chat_completion( - form_data=form_data, user=user, bypass_filter=bypass_filter - ) - if form_data.stream: - response.headers["content-type"] = "text/event-stream" - return StreamingResponse( - convert_streaming_response_ollama_to_openai(response), - headers=dict(response.headers), - ) - else: - return convert_response_ollama_to_openai(response) - else: - return await generate_openai_chat_completion( - form_data, user=user, bypass_filter=bypass_filter + return await process_chat_response(response, events) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), ) +# Alias for chat_completion (Legacy) +generate_chat_completions = chat_completion +generate_chat_completion = chat_completion + + @app.post("/api/chat/completed") -async def chat_completed(form_data: dict, user=Depends(get_verified_user)): - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - data = form_data - model_id = data["model"] - if model_id not in models: +async def chat_completed( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + try: + return await chat_completed_handler(request, form_data, user) + except Exception as e: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), ) - model = models[model_id] - sorted_filters = get_sorted_filters(model_id, models) - if "pipeline" in model: - sorted_filters = [model] + sorted_filters - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/outlet", - headers=headers, - json={ - "user": { - "id": user.id, - "name": user.name, - "email": user.email, - "role": user.role, - }, - "body": data, - }, - ) - - r.raise_for_status() - data = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - try: - res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) - except Exception: - pass - - else: - pass - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel to include 
vavles - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - # Sort filter_ids by priority, using the get_priority function - filter_ids.sort(key=get_priority) - - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if not hasattr(function_module, "outlet"): - continue - try: - outlet = function_module.outlet - - # Get the signature of the function - sig = inspect.signature(outlet) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(outlet): - data = await outlet(**params) - else: - data = outlet(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data - @app.post("/api/chat/actions/{action_id}") -async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)): - if "." 
in action_id: - action_id, sub_action_id = action_id.split(".") - else: - sub_action_id = None - - action = Functions.get_function_by_id(action_id) - if not action: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Action not found", - ) - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - data = form_data - model_id = data["model"] - - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - model = models[model_id] - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - if action_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[action_id] - else: - function_module, _, _ = load_function_module_by_id(action_id) - webui_app.state.FUNCTIONS[action_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(action_id) - function_module.valves = function_module.Valves(**(valves if valves else {})) - - if hasattr(function_module, "action"): - try: - action = function_module.action - - # Get the signature of the function - sig = inspect.signature(action) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": sub_action_id if sub_action_id is not None else action_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - action_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(action): - data = await action(**params) - else: - data = action(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data - - -################################## -# -# Task Endpoints -# -################################## - - -# TODO: Refactor task API endpoints below into a separate file - - -@app.get("/api/task/config") -async def get_task_config(user=Depends(get_verified_user)): - return { - "TASK_MODEL": app.state.config.TASK_MODEL, - "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, - "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, - "ENABLE_AUTOCOMPLETE_GENERATION": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, - "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, - "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, - "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, - "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, - "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, 
- "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, - "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - } - - -class TaskConfigForm(BaseModel): - TASK_MODEL: Optional[str] - TASK_MODEL_EXTERNAL: Optional[str] - TITLE_GENERATION_PROMPT_TEMPLATE: str - ENABLE_AUTOCOMPLETE_GENERATION: bool - AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int - TAGS_GENERATION_PROMPT_TEMPLATE: str - ENABLE_TAGS_GENERATION: bool - ENABLE_SEARCH_QUERY_GENERATION: bool - ENABLE_RETRIEVAL_QUERY_GENERATION: bool - QUERY_GENERATION_PROMPT_TEMPLATE: str - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str - - -@app.post("/api/task/config/update") -async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)): - app.state.config.TASK_MODEL = form_data.TASK_MODEL - app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL - app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( - form_data.TITLE_GENERATION_PROMPT_TEMPLATE - ) - - app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ( - form_data.ENABLE_AUTOCOMPLETE_GENERATION - ) - app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( - form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH - ) - - app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = ( - form_data.TAGS_GENERATION_PROMPT_TEMPLATE - ) - app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION - app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ( - form_data.ENABLE_SEARCH_QUERY_GENERATION - ) - app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ( - form_data.ENABLE_RETRIEVAL_QUERY_GENERATION - ) - - app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = ( - form_data.QUERY_GENERATION_PROMPT_TEMPLATE - ) - app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( - form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - ) - - return { - "TASK_MODEL": app.state.config.TASK_MODEL, - "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, - "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, - "ENABLE_AUTOCOMPLETE_GENERATION": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, - "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, - "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, - "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, - "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, - "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, - "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, - "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - } - - -@app.post("/api/task/title/completions") -async def generate_title(form_data: dict, user=Depends(get_verified_user)): - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id( - model_id, - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - - log.debug( - f"generating chat title using model {task_model_id} for user {user.email} " - ) - - if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE 
!= "": - template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE - else: - template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. - -Examples of titles: -📉 Stock Market Trends -🍪 Perfect Chocolate Chip Recipe -Evolution of Music Streaming -Remote Work Productivity Tips -Artificial Intelligence in Healthcare -🎮 Video Game Development Insights - - -{{MESSAGES:END:2}} -""" - - content = title_generation_template( - template, - form_data["messages"], - { - "name": user.name, - "location": user.info.get("location") if user.info else None, - }, - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - **( - {"max_tokens": 50} - if models[task_model_id]["owned_by"] == "ollama" - else { - "max_completion_tokens": 50, - } - ), - "metadata": { - "task": str(TASKS.TITLE_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), - }, - } - - # Handle pipeline filters +async def chat_action( + request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user) +): try: - payload = filter_pipeline(payload, user, models) + return await chat_action_handler(request, action_id, form_data, user) except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/tags/completions") -async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): - - if not app.state.config.ENABLE_TAGS_GENERATION: - return JSONResponse( - status_code=status.HTTP_200_OK, - content={"detail": "Tags generation is disabled"}, - ) - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id( - model_id, - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - - log.debug( - f"generating chat tags using model {task_model_id} for user {user.email} " - ) - - if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": - template = app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE - else: - template = """### Task: -Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags. - -### Guidelines: -- Start with high-level domains (e.g. 
Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education) -- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation -- If content is too short (less than 3 messages) or too diverse, use only ["General"] -- Use the chat's primary language; default to English if multilingual -- Prioritize accuracy over specificity - -### Output: -JSON format: { "tags": ["tag1", "tag2", "tag3"] } - -### Chat History: - -{{MESSAGES:END:6}} -""" - - content = tags_generation_template( - template, form_data["messages"], {"name": user.name} - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - "metadata": { - "task": str(TASKS.TAGS_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), - }, - } - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user, models) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/queries/completions") -async def generate_queries(form_data: dict, user=Depends(get_verified_user)): - - type = form_data.get("type") - if type == "web_search": - if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Search query generation is disabled", - ) - elif type == "retrieval": - if not app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Query generation is disabled", - ) - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id( - model_id, - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - - log.debug( - f"generating {type} queries using model {task_model_id} for user {user.email}" - ) - - if (app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "": - template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE - else: - template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE - - content = query_generation_template( - template, form_data["messages"], {"name": user.name} - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - "metadata": { - "task": str(TASKS.QUERY_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), - }, - } - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user, models) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/auto/completions") -async def 
generate_autocompletion(form_data: dict, user=Depends(get_verified_user)): - if not app.state.config.ENABLE_AUTOCOMPLETE_GENERATION: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Autocompletion generation is disabled", - ) - - type = form_data.get("type") - prompt = form_data.get("prompt") - messages = form_data.get("messages") - - if app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0: - if len(prompt) > app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Input prompt exceeds maximum length of {app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", - ) - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id( - model_id, - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - - log.debug( - f"generating autocompletion using model {task_model_id} for user {user.email}" - ) - - if (app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "": - template = app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE - else: - template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE - - content = autocomplete_generation_template( - template, prompt, messages, type, {"name": user.name} - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - "metadata": { - "task": str(TASKS.AUTOCOMPLETE_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None), - }, - } - - print(payload) - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user, models) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/emoji/completions") -async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id( - model_id, - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - - log.debug(f"generating emoji using model {task_model_id} for user {user.email} ") - - template = ''' -Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). 
- -Message: """{{prompt}}""" -''' - content = emoji_generation_template( - template, - form_data["prompt"], - { - "name": user.name, - "location": user.info.get("location") if user.info else None, - }, - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - **( - {"max_tokens": 4} - if models[task_model_id]["owned_by"] == "ollama" - else { - "max_completion_tokens": 4, - } - ), - "chat_id": form_data.get("chat_id", None), - "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data}, - } - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user, models) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/moa/completions") -async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)): - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id( - model_id, - app.state.config.TASK_MODEL, - app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - - log.debug(f"generating MOA model {task_model_id} for user {user.email} ") - - template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" - -Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. 
- -Responses from models: {{responses}}""" - - content = moa_response_generation_template( - template, - form_data["prompt"], - form_data["responses"], - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": form_data.get("stream", False), - "chat_id": form_data.get("chat_id", None), - "metadata": { - "task": str(TASKS.MOA_RESPONSE_GENERATION), - "task_body": form_data, - }, - } - - try: - payload = filter_pipeline(payload, user, models) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -################################## -# -# Pipelines Endpoints -# -################################## - - -# TODO: Refactor pipelines API endpoints below into a separate file - - -@app.get("/api/pipelines/list") -async def get_pipelines_list(user=Depends(get_admin_user)): - responses = await get_openai_models_responses() - - log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}") - urlIdxs = [ - idx - for idx, response in enumerate(responses) - if response is not None and "pipelines" in response - ] - - return { - "data": [ - { - "url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx], - "idx": urlIdx, - } - for urlIdx in urlIdxs - ] - } - - -@app.post("/api/pipelines/upload") -async def upload_pipeline( - urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) -): - print("upload_pipeline", urlIdx, file.filename) - # Check if the uploaded file is a python file - if not (file.filename and file.filename.endswith(".py")): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Only Python (.py) files are allowed.", - ) - - upload_folder = f"{CACHE_DIR}/pipelines" - os.makedirs(upload_folder, exist_ok=True) - file_path = os.path.join(upload_folder, file.filename) - - r = None - try: - # Save the uploaded file - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - - with open(file_path, "rb") as f: - files = {"file": f} - r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - status_code = status.HTTP_404_NOT_FOUND - if r is not None: - status_code = r.status_code - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=status_code, - detail=detail, - ) - finally: - # Ensure the file is deleted after the upload is completed or on failure - if os.path.exists(file_path): - os.remove(file_path) - - -class AddPipelineForm(BaseModel): - url: str - urlIdx: int - - -@app.post("/api/pipelines/add") -async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)): - r = None - try: - urlIdx = form_data.urlIdx - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = 
requests.post( - f"{url}/pipelines/add", headers=headers, json={"url": form_data.url} - ) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -class DeletePipelineForm(BaseModel): - id: str - urlIdx: int - - -@app.delete("/api/pipelines/delete") -async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)): - r = None - try: - urlIdx = form_data.urlIdx - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.delete( - f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id} - ) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.get("/api/pipelines") -async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/pipelines", headers=headers) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.get("/api/pipelines/{pipeline_id}/valves") -async def get_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), -): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.get("/api/pipelines/{pipeline_id}/valves/spec") -async def get_pipeline_valves_spec( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), -): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers) - - 
r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.post("/api/pipelines/{pipeline_id}/valves/update") -async def update_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - form_data: dict, - user=Depends(get_admin_user), -): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{pipeline_id}/valves/update", - headers=headers, - json={**form_data}, - ) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, + detail=str(e), ) @@ -2603,17 +920,17 @@ async def get_app_config(request: Request): }, "features": { "auth": WEBUI_AUTH, - "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), - "enable_ldap": webui_app.state.config.ENABLE_LDAP, - "enable_api_key": webui_app.state.config.ENABLE_API_KEY, - "enable_signup": webui_app.state.config.ENABLE_SIGNUP, - "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM, + "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER), + "enable_ldap": app.state.config.ENABLE_LDAP, + "enable_api_key": app.state.config.ENABLE_API_KEY, + "enable_signup": app.state.config.ENABLE_SIGNUP, + "enable_login_form": app.state.config.ENABLE_LOGIN_FORM, **( { - "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH, - "enable_image_generation": images_app.state.config.ENABLED, - "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, - "enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING, + "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH, + "enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION, + "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING, + "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING, "enable_admin_export": ENABLE_ADMIN_EXPORT, "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, } @@ -2623,23 +940,23 @@ async def get_app_config(request: Request): }, **( { - "default_models": webui_app.state.config.DEFAULT_MODELS, - "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, + "default_models": app.state.config.DEFAULT_MODELS, + "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "audio": { "tts": { - "engine": audio_app.state.config.TTS_ENGINE, - "voice": audio_app.state.config.TTS_VOICE, - "split_on": audio_app.state.config.TTS_SPLIT_ON, + "engine": app.state.config.TTS_ENGINE, + "voice": app.state.config.TTS_VOICE, + "split_on": app.state.config.TTS_SPLIT_ON, }, "stt": { - "engine": audio_app.state.config.STT_ENGINE, + "engine": app.state.config.STT_ENGINE, }, }, "file": { - "max_size": retrieval_app.state.config.FILE_MAX_SIZE, - "max_count": 
retrieval_app.state.config.FILE_MAX_COUNT, + "max_size": app.state.config.FILE_MAX_SIZE, + "max_count": app.state.config.FILE_MAX_COUNT, }, - "permissions": {**webui_app.state.config.USER_PERMISSIONS}, + "permissions": {**app.state.config.USER_PERMISSIONS}, } if user is not None else {} @@ -2647,7 +964,8 @@ async def get_app_config(request: Request): } -# TODO: webhook endpoint should be under config endpoints +class UrlForm(BaseModel): + url: str @app.get("/api/webhook") @@ -2657,14 +975,10 @@ async def get_webhook_url(user=Depends(get_admin_user)): } -class UrlForm(BaseModel): - url: str - - @app.post("/api/webhook") async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): app.state.config.WEBHOOK_URL = form_data.url - webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL + app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL return {"url": app.state.config.WEBHOOK_URL} @@ -2675,11 +989,6 @@ async def get_app_version(): } -@app.get("/api/changelog") -async def get_app_changelog(): - return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5} - - @app.get("/api/version/updates") async def get_app_latest_release_version(): if OFFLINE_MODE: @@ -2703,6 +1012,11 @@ async def get_app_latest_release_version(): return {"current": VERSION, "latest": VERSION} +@app.get("/api/changelog") +async def get_app_changelog(): + return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5} + + ############################ # OAuth Login & Callback ############################ @@ -2790,7 +1104,6 @@ async def healthcheck_with_db(): app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") - if os.path.exists(FRONTEND_BUILD_DIR): mimetypes.add_type("text/javascript", ".js") app.mount( diff --git a/backend/open_webui/migrations/env.py b/backend/open_webui/migrations/env.py index 5e860c8a0..128881647 100644 --- a/backend/open_webui/migrations/env.py +++ b/backend/open_webui/migrations/env.py @@ -1,7 +1,7 @@ from logging.config import fileConfig from alembic import context -from open_webui.apps.webui.models.auths import Auth +from open_webui.models.auths import Auth from open_webui.env import DATABASE_URL from sqlalchemy import engine_from_config, pool diff --git a/backend/open_webui/migrations/script.py.mako b/backend/open_webui/migrations/script.py.mako index 01e730e77..bcf5567fd 100644 --- a/backend/open_webui/migrations/script.py.mako +++ b/backend/open_webui/migrations/script.py.mako @@ -9,7 +9,7 @@ from typing import Sequence, Union from alembic import op import sqlalchemy as sa -import open_webui.apps.webui.internal.db +import open_webui.internal.db ${imports if imports else ""} # revision identifiers, used by Alembic. diff --git a/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py b/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py index 607a7b2c9..9e56282ef 100644 --- a/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py +++ b/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py @@ -11,8 +11,8 @@ from typing import Sequence, Union import sqlalchemy as sa from alembic import op -import open_webui.apps.webui.internal.db -from open_webui.apps.webui.internal.db import JSONField +import open_webui.internal.db +from open_webui.internal.db import JSONField from open_webui.migrations.util import get_existing_tables # revision identifiers, used by Alembic. 
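Note: the migration hunks above and the renames that follow are a pure module-layout change — everything that lived under open_webui.apps.webui.*, open_webui.apps.retrieval.*, and open_webui.apps.audio.* moves up one level to open_webui.*, with each import site updated to the new dotted path. For out-of-tree scripts or tools still importing the old paths, a small alias shim along the following lines could bridge the rename. This sketch is illustrative only, not part of this diff; the module list is a sample of the renames below, and it assumes the post-refactor packages are importable.

    # Hypothetical compatibility shim (not part of this changeset): republish
    # the relocated modules under their pre-refactor dotted names so external
    # code written against the old layout keeps importing during migration.
    import importlib
    import sys

    _RENAMES = {
        "open_webui.apps.webui.internal.db": "open_webui.internal.db",
        "open_webui.apps.webui.models.auths": "open_webui.models.auths",
        "open_webui.apps.webui.models.users": "open_webui.models.users",
        "open_webui.apps.retrieval.vector.main": "open_webui.retrieval.vector.main",
    }

    for _old_name, _new_name in _RENAMES.items():
        # Import the module at its new location, then register the same module
        # object in sys.modules under the old dotted name.
        sys.modules[_old_name] = importlib.import_module(_new_name)

Because `from a.b.c import x` consults sys.modules for the full dotted name before searching the filesystem, leaf entries like these suffice for the `from ... import ...` style used throughout this codebase; bare `import a.b.c` statements would additionally need the intermediate packages stubbed out.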
diff --git a/backend/open_webui/apps/webui/models/auths.py b/backend/open_webui/models/auths.py similarity index 96% rename from backend/open_webui/apps/webui/models/auths.py rename to backend/open_webui/models/auths.py index ead897347..f07c36c73 100644 --- a/backend/open_webui/apps/webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -2,12 +2,12 @@ import logging import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.users import UserModel, Users +from open_webui.internal.db import Base, get_db +from open_webui.models.users import UserModel, Users from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel from sqlalchemy import Boolean, Column, String, Text -from open_webui.utils.utils import verify_password +from open_webui.utils.auth import verify_password log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/models/chats.py similarity index 99% rename from backend/open_webui/apps/webui/models/chats.py rename to backend/open_webui/models/chats.py index 21250add8..3e621a150 100644 --- a/backend/open_webui/apps/webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -3,8 +3,8 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.tags import TagModel, Tag, Tags +from open_webui.internal.db import Base, get_db +from open_webui.models.tags import TagModel, Tag, Tags from pydantic import BaseModel, ConfigDict diff --git a/backend/open_webui/apps/webui/models/feedbacks.py b/backend/open_webui/models/feedbacks.py similarity index 98% rename from backend/open_webui/apps/webui/models/feedbacks.py rename to backend/open_webui/models/feedbacks.py index c2356dfd8..7ff5c4540 100644 --- a/backend/open_webui/apps/webui/models/feedbacks.py +++ b/backend/open_webui/models/feedbacks.py @@ -3,8 +3,8 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.chats import Chats +from open_webui.internal.db import Base, get_db +from open_webui.models.chats import Chats from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict diff --git a/backend/open_webui/apps/webui/models/files.py b/backend/open_webui/models/files.py similarity index 98% rename from backend/open_webui/apps/webui/models/files.py rename to backend/open_webui/models/files.py index 31c9164b6..4050b0140 100644 --- a/backend/open_webui/apps/webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -2,7 +2,7 @@ import logging import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db +from open_webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON diff --git a/backend/open_webui/apps/webui/models/folders.py b/backend/open_webui/models/folders.py similarity index 98% rename from backend/open_webui/apps/webui/models/folders.py rename to backend/open_webui/models/folders.py index 90e8880aa..040774196 100644 --- a/backend/open_webui/apps/webui/models/folders.py +++ b/backend/open_webui/models/folders.py @@ -3,8 +3,8 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db 
-from open_webui.apps.webui.models.chats import Chats +from open_webui.internal.db import Base, get_db +from open_webui.models.chats import Chats from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict diff --git a/backend/open_webui/apps/webui/models/functions.py b/backend/open_webui/models/functions.py similarity index 98% rename from backend/open_webui/apps/webui/models/functions.py rename to backend/open_webui/models/functions.py index fda155075..6c6aed862 100644 --- a/backend/open_webui/apps/webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -2,8 +2,8 @@ import logging import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db -from open_webui.apps.webui.models.users import Users +from open_webui.internal.db import Base, JSONField, get_db +from open_webui.models.users import Users from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Boolean, Column, String, Text diff --git a/backend/open_webui/apps/webui/models/groups.py b/backend/open_webui/models/groups.py similarity index 97% rename from backend/open_webui/apps/webui/models/groups.py rename to backend/open_webui/models/groups.py index e692198cd..8f0728411 100644 --- a/backend/open_webui/apps/webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -4,10 +4,10 @@ import time from typing import Optional import uuid -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from open_webui.env import SRC_LOG_LEVELS -from open_webui.apps.webui.models.files import FileMetadataResponse +from open_webui.models.files import FileMetadataResponse from pydantic import BaseModel, ConfigDict diff --git a/backend/open_webui/apps/webui/models/knowledge.py b/backend/open_webui/models/knowledge.py similarity index 97% rename from backend/open_webui/apps/webui/models/knowledge.py rename to backend/open_webui/models/knowledge.py index e1a13b3fd..bed3d5542 100644 --- a/backend/open_webui/apps/webui/models/knowledge.py +++ b/backend/open_webui/models/knowledge.py @@ -4,11 +4,11 @@ import time from typing import Optional import uuid -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from open_webui.env import SRC_LOG_LEVELS -from open_webui.apps.webui.models.files import FileMetadataResponse -from open_webui.apps.webui.models.users import Users, UserResponse +from open_webui.models.files import FileMetadataResponse +from open_webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict diff --git a/backend/open_webui/apps/webui/models/memories.py b/backend/open_webui/models/memories.py similarity index 98% rename from backend/open_webui/apps/webui/models/memories.py rename to backend/open_webui/models/memories.py index 6686058d3..c8dae9726 100644 --- a/backend/open_webui/apps/webui/models/memories.py +++ b/backend/open_webui/models/memories.py @@ -2,7 +2,7 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text diff --git a/backend/open_webui/apps/webui/models/models.py b/backend/open_webui/models/models.py similarity index 98% rename from backend/open_webui/apps/webui/models/models.py rename to backend/open_webui/models/models.py index 
50581bc73..f2f59d7c4 100644 --- a/backend/open_webui/apps/webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -2,10 +2,10 @@ import logging import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db +from open_webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS -from open_webui.apps.webui.models.users import Users, UserResponse +from open_webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict diff --git a/backend/open_webui/apps/webui/models/prompts.py b/backend/open_webui/models/prompts.py similarity index 97% rename from backend/open_webui/apps/webui/models/prompts.py rename to backend/open_webui/models/prompts.py index fe9999195..8ef4cd2be 100644 --- a/backend/open_webui/apps/webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -1,8 +1,8 @@ import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.users import Users, UserResponse +from open_webui.internal.db import Base, get_db +from open_webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON diff --git a/backend/open_webui/apps/webui/models/tags.py b/backend/open_webui/models/tags.py similarity index 98% rename from backend/open_webui/apps/webui/models/tags.py rename to backend/open_webui/models/tags.py index 7424a2660..3e812db95 100644 --- a/backend/open_webui/apps/webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -3,7 +3,7 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/apps/webui/models/tools.py b/backend/open_webui/models/tools.py similarity index 98% rename from backend/open_webui/apps/webui/models/tools.py rename to backend/open_webui/models/tools.py index 8f798c317..a5f13ebb7 100644 --- a/backend/open_webui/apps/webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -2,8 +2,8 @@ import logging import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db -from open_webui.apps.webui.models.users import Users, UserResponse +from open_webui.internal.db import Base, JSONField, get_db +from open_webui.models.users import Users, UserResponse from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON diff --git a/backend/open_webui/apps/webui/models/users.py b/backend/open_webui/models/users.py similarity index 98% rename from backend/open_webui/apps/webui/models/users.py rename to backend/open_webui/models/users.py index 5bbcc3099..5b6c27214 100644 --- a/backend/open_webui/apps/webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -1,8 +1,8 @@ import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db -from open_webui.apps.webui.models.chats import Chats +from open_webui.internal.db import Base, JSONField, get_db +from open_webui.models.chats import Chats from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text diff --git a/backend/open_webui/apps/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py similarity index 96% rename from 
backend/open_webui/apps/retrieval/loaders/main.py rename to backend/open_webui/retrieval/loaders/main.py index 36f03cbb2..a9372f65a 100644 --- a/backend/open_webui/apps/retrieval/loaders/main.py +++ b/backend/open_webui/retrieval/loaders/main.py @@ -1,6 +1,7 @@ import requests import logging import ftfy +import sys from langchain_community.document_loaders import ( BSHTMLLoader, @@ -18,8 +19,9 @@ from langchain_community.document_loaders import ( YoutubeLoader, ) from langchain_core.documents import Document -from open_webui.env import SRC_LOG_LEVELS +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -106,7 +108,7 @@ class TikaLoader: if "Content-Type" in raw_metadata: headers["Content-Type"] = raw_metadata["Content-Type"] - log.info("Tika extracted text: %s", text) + log.debug("Tika extracted text: %s", text) return [Document(page_content=text, metadata=headers)] else: diff --git a/backend/open_webui/apps/retrieval/loaders/youtube.py b/backend/open_webui/retrieval/loaders/youtube.py similarity index 100% rename from backend/open_webui/apps/retrieval/loaders/youtube.py rename to backend/open_webui/retrieval/loaders/youtube.py diff --git a/backend/open_webui/apps/retrieval/models/colbert.py b/backend/open_webui/retrieval/models/colbert.py similarity index 100% rename from backend/open_webui/apps/retrieval/models/colbert.py rename to backend/open_webui/retrieval/models/colbert.py diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/retrieval/utils.py similarity index 99% rename from backend/open_webui/apps/retrieval/utils.py rename to backend/open_webui/retrieval/utils.py index bf939ecf1..9444ade95 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -11,7 +11,7 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/retrieval/vector/connector.py b/backend/open_webui/retrieval/vector/connector.py new file mode 100644 index 000000000..bf97bc7b1 --- /dev/null +++ b/backend/open_webui/retrieval/vector/connector.py @@ -0,0 +1,22 @@ +from open_webui.config import VECTOR_DB + +if VECTOR_DB == "milvus": + from open_webui.retrieval.vector.dbs.milvus import MilvusClient + + VECTOR_DB_CLIENT = MilvusClient() +elif VECTOR_DB == "qdrant": + from open_webui.retrieval.vector.dbs.qdrant import QdrantClient + + VECTOR_DB_CLIENT = QdrantClient() +elif VECTOR_DB == "opensearch": + from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient + + VECTOR_DB_CLIENT = OpenSearchClient() +elif VECTOR_DB == "pgvector": + from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient + + VECTOR_DB_CLIENT = PgvectorClient() +else: + from open_webui.retrieval.vector.dbs.chroma import ChromaClient + + VECTOR_DB_CLIENT = ChromaClient() diff --git a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py similarity index 98% rename from backend/open_webui/apps/retrieval/vector/dbs/chroma.py rename to 
backend/open_webui/retrieval/vector/dbs/chroma.py index b2fcdd16a..00d73a889 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches from typing import Optional -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( CHROMA_DATA_PATH, CHROMA_HTTP_HOST, diff --git a/backend/open_webui/apps/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py similarity index 99% rename from backend/open_webui/apps/retrieval/vector/dbs/milvus.py rename to backend/open_webui/retrieval/vector/dbs/milvus.py index 5351f860e..31d890664 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -4,7 +4,7 @@ import json from typing import Optional -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( MILVUS_URI, ) diff --git a/backend/open_webui/apps/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py similarity index 98% rename from backend/open_webui/apps/retrieval/vector/dbs/opensearch.py rename to backend/open_webui/retrieval/vector/dbs/opensearch.py index 6234b2837..b3d8b5eb8 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -1,7 +1,7 @@ from opensearchpy import OpenSearch from typing import Optional -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( OPENSEARCH_URI, OPENSEARCH_SSL, diff --git a/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py similarity index 98% rename from backend/open_webui/apps/retrieval/vector/dbs/pgvector.py rename to backend/open_webui/retrieval/vector/dbs/pgvector.py index d537943a1..cb8c545e9 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -18,7 +18,7 @@ from sqlalchemy.dialects.postgresql import JSONB, array from pgvector.sqlalchemy import Vector from sqlalchemy.ext.mutable import MutableDict -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import PGVECTOR_DB_URL VECTOR_LENGTH = 1536 @@ -40,7 +40,7 @@ class PgvectorClient: # if no pgvector uri, use the existing database connection if not PGVECTOR_DB_URL: - from open_webui.apps.webui.internal.db import Session + from open_webui.internal.db import Session self.session = Session else: diff --git a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py similarity index 98% rename from backend/open_webui/apps/retrieval/vector/dbs/qdrant.py rename to backend/open_webui/retrieval/vector/dbs/qdrant.py index 60c1c3d4d..f077ae45a 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -4,7 +4,7 @@ from qdrant_client import QdrantClient as Qclient from 
qdrant_client.http.models import PointStruct from qdrant_client.models import models -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import QDRANT_URI, QDRANT_API_KEY NO_LIMIT = 999999999 diff --git a/backend/open_webui/apps/retrieval/vector/main.py b/backend/open_webui/retrieval/vector/main.py similarity index 100% rename from backend/open_webui/apps/retrieval/vector/main.py rename to backend/open_webui/retrieval/vector/main.py diff --git a/backend/open_webui/apps/retrieval/web/bing.py b/backend/open_webui/retrieval/web/bing.py similarity index 96% rename from backend/open_webui/apps/retrieval/web/bing.py rename to backend/open_webui/retrieval/web/bing.py index b5f889c54..09beb3460 100644 --- a/backend/open_webui/apps/retrieval/web/bing.py +++ b/backend/open_webui/retrieval/web/bing.py @@ -3,7 +3,7 @@ import os from pprint import pprint from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS import argparse diff --git a/backend/open_webui/apps/retrieval/web/brave.py b/backend/open_webui/retrieval/web/brave.py similarity index 93% rename from backend/open_webui/apps/retrieval/web/brave.py rename to backend/open_webui/retrieval/web/brave.py index f988b3b08..3075db990 100644 --- a/backend/open_webui/apps/retrieval/web/brave.py +++ b/backend/open_webui/retrieval/web/brave.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/duckduckgo.py b/backend/open_webui/retrieval/web/duckduckgo.py similarity index 95% rename from backend/open_webui/apps/retrieval/web/duckduckgo.py rename to backend/open_webui/retrieval/web/duckduckgo.py index 11e512296..7c0c3f1c2 100644 --- a/backend/open_webui/apps/retrieval/web/duckduckgo.py +++ b/backend/open_webui/retrieval/web/duckduckgo.py @@ -1,7 +1,7 @@ import logging from typing import Optional -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from duckduckgo_search import DDGS from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/apps/retrieval/web/google_pse.py b/backend/open_webui/retrieval/web/google_pse.py similarity index 94% rename from backend/open_webui/apps/retrieval/web/google_pse.py rename to backend/open_webui/retrieval/web/google_pse.py index 61b919583..2c51dd3c9 100644 --- a/backend/open_webui/apps/retrieval/web/google_pse.py +++ b/backend/open_webui/retrieval/web/google_pse.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/jina_search.py b/backend/open_webui/retrieval/web/jina_search.py similarity index 94% rename from backend/open_webui/apps/retrieval/web/jina_search.py 
rename to backend/open_webui/retrieval/web/jina_search.py index f5e2febbe..3de6c1807 100644 --- a/backend/open_webui/apps/retrieval/web/jina_search.py +++ b/backend/open_webui/retrieval/web/jina_search.py @@ -1,7 +1,7 @@ import logging import requests -from open_webui.apps.retrieval.web.main import SearchResult +from open_webui.retrieval.web.main import SearchResult from open_webui.env import SRC_LOG_LEVELS from yarl import URL diff --git a/backend/open_webui/retrieval/web/kagi.py b/backend/open_webui/retrieval/web/kagi.py new file mode 100644 index 000000000..0b69da8bc --- /dev/null +++ b/backend/open_webui/retrieval/web/kagi.py @@ -0,0 +1,48 @@ +import logging +from typing import Optional + +import requests +from open_webui.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_kagi( + api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None +) -> list[SearchResult]: + """Search using Kagi's Search API and return the results as a list of SearchResult objects. + + The Search API will inherit the settings in your account, including results personalization and snippet length. + + Args: + api_key (str): A Kagi Search API key + query (str): The query to search for + count (int): The number of results to return + """ + url = "https://kagi.com/api/v0/search" + headers = { + "Authorization": f"Bot {api_key}", + } + params = {"q": query, "limit": count} + + response = requests.get(url, headers=headers, params=params) + response.raise_for_status() + json_response = response.json() + search_results = json_response.get("data", []) + + results = [ + SearchResult( + link=result["url"], title=result["title"], snippet=result.get("snippet") + ) + for result in search_results + if result["t"] == 0 + ] + + print(results) + + if filter_list: + results = get_filtered_results(results, filter_list) + + return results diff --git a/backend/open_webui/apps/retrieval/web/main.py b/backend/open_webui/retrieval/web/main.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/main.py rename to backend/open_webui/retrieval/web/main.py diff --git a/backend/open_webui/apps/retrieval/web/mojeek.py b/backend/open_webui/retrieval/web/mojeek.py similarity index 93% rename from backend/open_webui/apps/retrieval/web/mojeek.py rename to backend/open_webui/retrieval/web/mojeek.py index f257c92aa..d298b0ee5 100644 --- a/backend/open_webui/apps/retrieval/web/mojeek.py +++ b/backend/open_webui/retrieval/web/mojeek.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/searchapi.py b/backend/open_webui/retrieval/web/searchapi.py similarity index 93% rename from backend/open_webui/apps/retrieval/web/searchapi.py rename to backend/open_webui/retrieval/web/searchapi.py index 412dc6b69..38bc0b574 100644 --- a/backend/open_webui/apps/retrieval/web/searchapi.py +++ b/backend/open_webui/retrieval/web/searchapi.py @@ -3,7 +3,7 @@ from typing import Optional from urllib.parse import urlencode import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, 
get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/searxng.py b/backend/open_webui/retrieval/web/searxng.py similarity index 97% rename from backend/open_webui/apps/retrieval/web/searxng.py rename to backend/open_webui/retrieval/web/searxng.py index cb1eaf91d..15e3c098a 100644 --- a/backend/open_webui/apps/retrieval/web/searxng.py +++ b/backend/open_webui/retrieval/web/searxng.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/serper.py b/backend/open_webui/retrieval/web/serper.py similarity index 93% rename from backend/open_webui/apps/retrieval/web/serper.py rename to backend/open_webui/retrieval/web/serper.py index 436fa167e..685e34375 100644 --- a/backend/open_webui/apps/retrieval/web/serper.py +++ b/backend/open_webui/retrieval/web/serper.py @@ -3,7 +3,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/serply.py b/backend/open_webui/retrieval/web/serply.py similarity index 95% rename from backend/open_webui/apps/retrieval/web/serply.py rename to backend/open_webui/retrieval/web/serply.py index 1c2521c47..a9b473eb0 100644 --- a/backend/open_webui/apps/retrieval/web/serply.py +++ b/backend/open_webui/retrieval/web/serply.py @@ -3,7 +3,7 @@ from typing import Optional from urllib.parse import urlencode import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/serpstack.py b/backend/open_webui/retrieval/web/serpstack.py similarity index 94% rename from backend/open_webui/apps/retrieval/web/serpstack.py rename to backend/open_webui/retrieval/web/serpstack.py index b655934de..d4dbda57c 100644 --- a/backend/open_webui/apps/retrieval/web/serpstack.py +++ b/backend/open_webui/retrieval/web/serpstack.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/tavily.py b/backend/open_webui/retrieval/web/tavily.py similarity index 94% rename from backend/open_webui/apps/retrieval/web/tavily.py rename to backend/open_webui/retrieval/web/tavily.py index 03b0be75a..cc468725d 100644 --- a/backend/open_webui/apps/retrieval/web/tavily.py +++ b/backend/open_webui/retrieval/web/tavily.py @@ -1,7 +1,7 @@ import logging import requests -from open_webui.apps.retrieval.web.main import SearchResult +from open_webui.retrieval.web.main import SearchResult from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git 
a/backend/open_webui/apps/retrieval/web/testdata/bing.json b/backend/open_webui/retrieval/web/testdata/bing.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/bing.json rename to backend/open_webui/retrieval/web/testdata/bing.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/brave.json b/backend/open_webui/retrieval/web/testdata/brave.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/brave.json rename to backend/open_webui/retrieval/web/testdata/brave.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/google_pse.json b/backend/open_webui/retrieval/web/testdata/google_pse.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/google_pse.json rename to backend/open_webui/retrieval/web/testdata/google_pse.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/searchapi.json b/backend/open_webui/retrieval/web/testdata/searchapi.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/searchapi.json rename to backend/open_webui/retrieval/web/testdata/searchapi.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/searxng.json b/backend/open_webui/retrieval/web/testdata/searxng.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/searxng.json rename to backend/open_webui/retrieval/web/testdata/searxng.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/serper.json b/backend/open_webui/retrieval/web/testdata/serper.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/serper.json rename to backend/open_webui/retrieval/web/testdata/serper.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/serply.json b/backend/open_webui/retrieval/web/testdata/serply.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/serply.json rename to backend/open_webui/retrieval/web/testdata/serply.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/serpstack.json b/backend/open_webui/retrieval/web/testdata/serpstack.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/serpstack.json rename to backend/open_webui/retrieval/web/testdata/serpstack.json diff --git a/backend/open_webui/apps/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/utils.py rename to backend/open_webui/retrieval/web/utils.py diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py new file mode 100644 index 000000000..a26355945 --- /dev/null +++ b/backend/open_webui/routers/audio.py @@ -0,0 +1,703 @@ +import hashlib +import json +import logging +import os +import uuid +from functools import lru_cache +from pathlib import Path +from pydub import AudioSegment +from pydub.silence import split_on_silence + +import aiohttp +import aiofiles +import requests + +from fastapi import ( + Depends, + FastAPI, + File, + HTTPException, + Request, + UploadFile, + status, + APIRouter, +) +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from pydantic import BaseModel + + +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.config import ( + WHISPER_MODEL_AUTO_UPDATE, + WHISPER_MODEL_DIR, + CACHE_DIR, +) + +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import ( + ENV, + SRC_LOG_LEVELS, + DEVICE_TYPE, + 
ENABLE_FORWARD_USER_INFO_HEADERS, +) + + +router = APIRouter() + +# Constants +MAX_FILE_SIZE_MB = 25 +MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["AUDIO"]) + +SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") +SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) + + +########################################## +# +# Utility functions +# +########################################## + +from pydub import AudioSegment +from pydub.utils import mediainfo + + +def is_mp4_audio(file_path): + """Check if the given file is an MP4 audio file.""" + if not os.path.isfile(file_path): + print(f"File not found: {file_path}") + return False + + info = mediainfo(file_path) + if ( + info.get("codec_name") == "aac" + and info.get("codec_type") == "audio" + and info.get("codec_tag_string") == "mp4a" + ): + return True + return False + + +def convert_mp4_to_wav(file_path, output_path): + """Convert MP4 audio file to WAV format.""" + audio = AudioSegment.from_file(file_path, format="mp4") + audio.export(output_path, format="wav") + print(f"Converted {file_path} to {output_path}") + + +def set_faster_whisper_model(model: str, auto_update: bool = False): + whisper_model = None + if model: + from faster_whisper import WhisperModel + + faster_whisper_kwargs = { + "model_size_or_path": model, + "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu", + "compute_type": "int8", + "download_root": WHISPER_MODEL_DIR, + "local_files_only": not auto_update, + } + + try: + whisper_model = WhisperModel(**faster_whisper_kwargs) + except Exception: + log.warning( + "WhisperModel initialization failed, attempting download with local_files_only=False" + ) + faster_whisper_kwargs["local_files_only"] = False + whisper_model = WhisperModel(**faster_whisper_kwargs) + return whisper_model + + +########################################## +# +# Audio API +# +########################################## + + +class TTSConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + API_KEY: str + ENGINE: str + MODEL: str + VOICE: str + SPLIT_ON: str + AZURE_SPEECH_REGION: str + AZURE_SPEECH_OUTPUT_FORMAT: str + + +class STTConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + ENGINE: str + MODEL: str + WHISPER_MODEL: str + + +class AudioConfigUpdateForm(BaseModel): + tts: TTSConfigForm + stt: STTConfigForm + + +@router.get("/config") +async def get_audio_config(request: Request, user=Depends(get_admin_user)): + return { + "tts": { + "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY, + "API_KEY": request.app.state.config.TTS_API_KEY, + "ENGINE": request.app.state.config.TTS_ENGINE, + "MODEL": request.app.state.config.TTS_MODEL, + "VOICE": request.app.state.config.TTS_VOICE, + "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON, + "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION, + "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, + }, + "stt": { + "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY, + "ENGINE": request.app.state.config.STT_ENGINE, + "MODEL": request.app.state.config.STT_MODEL, + "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, + }, + } + + +@router.post("/config/update") +async def update_audio_config( + request: Request, form_data: 
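The reworked `set_faster_whisper_model` above now returns the model instead of stashing it on module-level app state, but keeps the same two-phase load: try the local cache first, then retry with downloads enabled. A standalone sketch of that fallback pattern, with illustrative paths (only the keyword arguments mirror the diff):

```python
# Sketch of the load-locally-then-download fallback used by set_faster_whisper_model.
from faster_whisper import WhisperModel

def load_whisper(model: str, download_root: str, auto_update: bool = False) -> WhisperModel:
    kwargs = {
        "model_size_or_path": model,
        "device": "cpu",                       # "cuda" when a GPU is available
        "compute_type": "int8",
        "download_root": download_root,
        "local_files_only": not auto_update,   # first attempt: local cache only
    }
    try:
        return WhisperModel(**kwargs)
    except Exception:
        # Cache miss (or corrupt local files): retry with downloads enabled.
        kwargs["local_files_only"] = False
        return WhisperModel(**kwargs)
```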
AudioConfigUpdateForm, user=Depends(get_admin_user) +): + request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL + request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY + request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY + request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE + request.app.state.config.TTS_MODEL = form_data.tts.MODEL + request.app.state.config.TTS_VOICE = form_data.tts.VOICE + request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON + request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION + request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = ( + form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT + ) + + request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL + request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY + request.app.state.config.STT_ENGINE = form_data.stt.ENGINE + request.app.state.config.STT_MODEL = form_data.stt.MODEL + request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL + + if request.app.state.config.STT_ENGINE == "": + request.app.state.faster_whisper_model = set_faster_whisper_model( + form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE + ) + + return { + "tts": { + "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY, + "API_KEY": request.app.state.config.TTS_API_KEY, + "ENGINE": request.app.state.config.TTS_ENGINE, + "MODEL": request.app.state.config.TTS_MODEL, + "VOICE": request.app.state.config.TTS_VOICE, + "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON, + "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION, + "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, + }, + "stt": { + "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY, + "ENGINE": request.app.state.config.STT_ENGINE, + "MODEL": request.app.state.config.STT_MODEL, + "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, + }, + } + + +def load_speech_pipeline(): + from transformers import pipeline + from datasets import load_dataset + + if request.app.state.speech_synthesiser is None: + request.app.state.speech_synthesiser = pipeline( + "text-to-speech", "microsoft/speecht5_tts" + ) + + if request.app.state.speech_speaker_embeddings_dataset is None: + request.app.state.speech_speaker_embeddings_dataset = load_dataset( + "Matthijs/cmu-arctic-xvectors", split="validation" + ) + + +@router.post("/speech") +async def speech(request: Request, user=Depends(get_verified_user)): + body = await request.body() + name = hashlib.sha256(body).hexdigest() + + file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3") + file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json") + + # Check if the file already exists in the cache + if file_path.is_file(): + return FileResponse(file_path) + + payload = None + try: + payload = json.loads(body.decode("utf-8")) + except Exception as e: + log.exception(e) + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + if request.app.state.config.TTS_ENGINE == "openai": + payload["model"] = request.app.state.config.TTS_MODEL + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", + data=payload, + headers={ + "Content-Type": "application/json", + "Authorization": 
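The `/speech` route that follows is content-addressed: the SHA-256 of the raw request body names both the cached MP3 and a JSON sidecar recording the request that produced it, so a byte-identical request is served straight from disk. A minimal sketch of that cache-key scheme (the path and payload are illustrative):

```python
# Content-addressed cache key as used by the /speech endpoint: hash the raw body.
import hashlib
import json
from pathlib import Path

SPEECH_CACHE_DIR = Path("cache/audio/speech")   # illustrative location

body = json.dumps({"input": "Hello, world!", "voice": "alloy"}).encode("utf-8")
name = hashlib.sha256(body).hexdigest()

file_path = SPEECH_CACHE_DIR / f"{name}.mp3"        # cached audio
file_body_path = SPEECH_CACHE_DIR / f"{name}.json"  # request that produced it

print(file_path.is_file())  # True -> serve from cache, skip the TTS backend entirely
```

Because the key is the raw bytes, any difference in key order or whitespace yields a distinct cache entry, even for semantically identical requests.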
f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, + ) as r: + r.raise_for_status() + + async with aiofiles.open(file_path, "wb") as f: + await f.write(await r.read()) + + async with aiofiles.open(file_body_path, "w") as f: + await f.write(json.dumps(json.loads(body.decode("utf-8")))) + + return FileResponse(file_path) + + except Exception as e: + log.exception(e) + detail = None + + try: + if r.status != 200: + res = await r.json() + if "error" in res: + detail = f"External: {res['error'].get('message', '')}" + except Exception: + detail = f"External: {e}" + + raise HTTPException( + status_code=getattr(r, "status", 500), + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + elif request.app.state.config.TTS_ENGINE == "elevenlabs": + voice_id = payload.get("voice", "") + + if voice_id not in get_available_voices(): + raise HTTPException( + status_code=400, + detail="Invalid voice id", + ) + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}", + json={ + "text": payload["input"], + "model_id": request.app.state.config.TTS_MODEL, + "voice_settings": {"stability": 0.5, "similarity_boost": 0.5}, + }, + headers={ + "Accept": "audio/mpeg", + "Content-Type": "application/json", + "xi-api-key": request.app.state.config.TTS_API_KEY, + }, + ) as r: + r.raise_for_status() + + async with aiofiles.open(file_path, "wb") as f: + await f.write(await r.read()) + + async with aiofiles.open(file_body_path, "w") as f: + await f.write(json.dumps(json.loads(body.decode("utf-8")))) + + return FileResponse(file_path) + + except Exception as e: + log.exception(e) + detail = None + + try: + if r.status != 200: + res = await r.json() + if "error" in res: + detail = f"External: {res['error'].get('message', '')}" + except Exception: + detail = f"External: {e}" + + raise HTTPException( + status_code=getattr(r, "status", 500), + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + elif request.app.state.config.TTS_ENGINE == "azure": + try: + payload = json.loads(body.decode("utf-8")) + except Exception as e: + log.exception(e) + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + region = request.app.state.config.TTS_AZURE_SPEECH_REGION + language = request.app.state.config.TTS_VOICE + locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1]) + output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT + + try: + data = f""" + {payload["input"]} + """ + async with aiohttp.ClientSession() as session: + async with session.post( + f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1", + headers={ + "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY, + "Content-Type": "application/ssml+xml", + "X-Microsoft-OutputFormat": output_format, + }, + data=data, + ) as r: + r.raise_for_status() + + async with aiofiles.open(file_path, "wb") as f: + await f.write(await r.read()) + + return FileResponse(file_path) + + except Exception as e: + log.exception(e) + detail = None + + try: + if r.status != 200: + res = await r.json() + if "error" in res: + detail = f"External: {res['error'].get('message', '')}" + except Exception: + detail = f"External: {e}" + + raise HTTPException( + status_code=getattr(r, 
"status", 500), + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + elif request.app.state.config.TTS_ENGINE == "transformers": + payload = None + try: + payload = json.loads(body.decode("utf-8")) + except Exception as e: + log.exception(e) + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + import torch + import soundfile as sf + + load_speech_pipeline() + + embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset + + speaker_index = 6799 + try: + speaker_index = embeddings_dataset["filename"].index( + request.app.state.config.TTS_MODEL + ) + except Exception: + pass + + speaker_embedding = torch.tensor( + embeddings_dataset[speaker_index]["xvector"] + ).unsqueeze(0) + + speech = request.app.state.speech_synthesiser( + payload["input"], + forward_params={"speaker_embeddings": speaker_embedding}, + ) + + sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"]) + with open(file_body_path, "w") as f: + json.dump(json.loads(body.decode("utf-8")), f) + + return FileResponse(file_path) + + +def transcribe(request: Request, file_path): + print("transcribe", file_path) + filename = os.path.basename(file_path) + file_dir = os.path.dirname(file_path) + id = filename.split(".")[0] + + if request.app.state.config.STT_ENGINE == "": + if request.app.state.faster_whisper_model is None: + request.app.state.faster_whisper_model = set_faster_whisper_model( + request.app.state.config.WHISPER_MODEL + ) + + model = request.app.state.faster_whisper_model + segments, info = model.transcribe(file_path, beam_size=5) + log.info( + "Detected language '%s' with probability %f" + % (info.language, info.language_probability) + ) + + transcript = "".join([segment.text for segment in list(segments)]) + data = {"text": transcript.strip()} + + # save the transcript to a json file + transcript_file = f"{file_dir}/{id}.json" + with open(transcript_file, "w") as f: + json.dump(data, f) + + log.debug(data) + return data + elif request.app.state.config.STT_ENGINE == "openai": + if is_mp4_audio(file_path): + os.rename(file_path, file_path.replace(".wav", ".mp4")) + # Convert MP4 audio file to WAV format + convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path) + + r = None + try: + r = requests.post( + url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", + headers={ + "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}" + }, + files={"file": (filename, open(file_path, "rb"))}, + data={"model": request.app.state.config.STT_MODEL}, + ) + + r.raise_for_status() + data = r.json() + + # save the transcript to a json file + transcript_file = f"{file_dir}/{id}.json" + with open(transcript_file, "w") as f: + json.dump(data, f) + + return data + except Exception as e: + log.exception(e) + + detail = None + if r is not None: + try: + res = r.json() + if "error" in res: + detail = f"External: {res['error'].get('message', '')}" + except Exception: + detail = f"External: {e}" + + raise Exception(detail if detail else "Open WebUI: Server Connection Error") + + +def compress_audio(file_path): + if os.path.getsize(file_path) > MAX_FILE_SIZE: + file_dir = os.path.dirname(file_path) + audio = AudioSegment.from_file(file_path) + audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio + compressed_path = f"{file_dir}/{id}_compressed.opus" + audio.export(compressed_path, format="opus", bitrate="32k") + log.debug(f"Compressed audio to {compressed_path}") + + if ( + os.path.getsize(compressed_path) > 
MAX_FILE_SIZE + ): # Still larger than MAX_FILE_SIZE after compression + raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB")) + return compressed_path + else: + return file_path + + +@router.post("/transcriptions") +def transcription( + request: Request, + file: UploadFile = File(...), + user=Depends(get_verified_user), +): + log.info(f"file.content_type: {file.content_type}") + + if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED, + ) + + try: + ext = file.filename.split(".")[-1] + id = uuid.uuid4() + + filename = f"{id}.{ext}" + contents = file.file.read() + + file_dir = f"{CACHE_DIR}/audio/transcriptions" + os.makedirs(file_dir, exist_ok=True) + file_path = f"{file_dir}/{filename}" + + with open(file_path, "wb") as f: + f.write(contents) + + try: + try: + file_path = compress_audio(file_path) + except Exception as e: + log.exception(e) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + data = transcribe(request, file_path) + file_path = file_path.split("/")[-1] + return {**data, "filename": file_path} + except Exception as e: + log.exception(e) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + except Exception as e: + log.exception(e) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +def get_available_models(request: Request) -> list[dict]: + available_models = [] + if request.app.state.config.TTS_ENGINE == "openai": + available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}] + elif request.app.state.config.TTS_ENGINE == "elevenlabs": + try: + response = requests.get( + "https://api.elevenlabs.io/v1/models", + headers={ + "xi-api-key": request.app.state.config.TTS_API_KEY, + "Content-Type": "application/json", + }, + timeout=5, + ) + response.raise_for_status() + models = response.json() + + available_models = [ + {"name": model["name"], "id": model["model_id"]} for model in models + ] + except requests.RequestException as e: + log.error(f"Error fetching voices: {str(e)}") + return available_models + + +@router.get("/models") +async def get_models(request: Request, user=Depends(get_verified_user)): + return {"models": get_available_models(request)} + + +def get_available_voices(request) -> dict: + """Returns {voice_id: voice_name} dict""" + available_voices = {} + if request.app.state.config.TTS_ENGINE == "openai": + available_voices = { + "alloy": "alloy", + "echo": "echo", + "fable": "fable", + "onyx": "onyx", + "nova": "nova", + "shimmer": "shimmer", + } + elif request.app.state.config.TTS_ENGINE == "elevenlabs": + try: + available_voices = get_elevenlabs_voices( + api_key=request.app.state.config.TTS_API_KEY + ) + except Exception: + # Avoided @lru_cache with exception + pass + elif request.app.state.config.TTS_ENGINE == "azure": + try: + region = request.app.state.config.TTS_AZURE_SPEECH_REGION + url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list" + headers = { + "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY + } + + response = requests.get(url, headers=headers) + response.raise_for_status() + voices = response.json() + + for voice in voices: + available_voices[voice["ShortName"]] = ( + f"{voice['DisplayName']} ({voice['ShortName']})" + ) + except requests.RequestException as e: + 
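`compress_audio` above downsamples oversized uploads to mono 16 kHz Opus at 32 kbps before transcription. A corrected standalone sketch of that step; note that the diff interpolates the Python builtin `id` into the output filename, whereas this sketch derives the stem from the input path (pydub requires ffmpeg on the PATH):

```python
# Standalone sketch of the compress_audio step; stem is taken from the input path
# rather than the builtin `id` the diff interpolates. 25 MB mirrors MAX_FILE_SIZE_MB.
import os
from pydub import AudioSegment  # requires ffmpeg

MAX_FILE_SIZE = 25 * 1024 * 1024

def compress_audio(file_path: str) -> str:
    if os.path.getsize(file_path) <= MAX_FILE_SIZE:
        return file_path  # already small enough, no re-encode needed
    stem, _ = os.path.splitext(file_path)
    audio = AudioSegment.from_file(file_path)
    audio = audio.set_frame_rate(16000).set_channels(1)  # mono 16 kHz
    compressed_path = f"{stem}_compressed.opus"
    audio.export(compressed_path, format="opus", bitrate="32k")
    if os.path.getsize(compressed_path) > MAX_FILE_SIZE:
        raise ValueError("File still exceeds 25MB after compression")
    return compressed_path
```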
log.error(f"Error fetching voices: {str(e)}") + + return available_voices + + +@lru_cache +def get_elevenlabs_voices(api_key: str) -> dict: + """ + Note, set the following in your .env file to use Elevenlabs: + AUDIO_TTS_ENGINE=elevenlabs + AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key + AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices + AUDIO_TTS_MODEL=eleven_multilingual_v2 + """ + + try: + # TODO: Add retries + response = requests.get( + "https://api.elevenlabs.io/v1/voices", + headers={ + "xi-api-key": api_key, + "Content-Type": "application/json", + }, + ) + response.raise_for_status() + voices_data = response.json() + + voices = {} + for voice in voices_data.get("voices", []): + voices[voice["voice_id"]] = voice["name"] + except requests.RequestException as e: + # Avoid @lru_cache with exception + log.error(f"Error fetching voices: {str(e)}") + raise RuntimeError(f"Error fetching voices: {str(e)}") + + return voices + + +@router.get("/voices") +async def get_voices(request: Request, user=Depends(get_verified_user)): + return { + "voices": [ + {"id": k, "name": v} for k, v in get_available_voices(request).items() + ] + } diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/routers/auths.py similarity index 94% rename from backend/open_webui/apps/webui/routers/auths.py rename to backend/open_webui/routers/auths.py index 8f175f366..0b1f42edf 100644 --- a/backend/open_webui/apps/webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -3,8 +3,9 @@ import uuid import time import datetime import logging +from aiohttp import ClientSession -from open_webui.apps.webui.models.auths import ( +from open_webui.models.auths import ( AddUserForm, ApiKey, Auths, @@ -17,7 +18,7 @@ from open_webui.apps.webui.models.auths import ( UpdateProfileForm, UserResponse, ) -from open_webui.apps.webui.models.users import Users +from open_webui.models.users import Users from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.env import ( @@ -29,10 +30,14 @@ from open_webui.env import ( SRC_LOG_LEVELS, ) from fastapi import APIRouter, Depends, HTTPException, Request, status -from fastapi.responses import Response +from fastapi.responses import RedirectResponse, Response +from open_webui.config import ( + OPENID_PROVIDER_URL, + ENABLE_OAUTH_SIGNUP, +) from pydantic import BaseModel from open_webui.utils.misc import parse_duration, validate_email_format -from open_webui.utils.utils import ( +from open_webui.utils.auth import ( create_api_key, create_token, get_admin_user, @@ -498,8 +503,31 @@ async def signup(request: Request, response: Response, form_data: SignupForm): @router.get("/signout") -async def signout(response: Response): +async def signout(request: Request, response: Response): response.delete_cookie("token") + + if ENABLE_OAUTH_SIGNUP.value: + oauth_id_token = request.cookies.get("oauth_id_token") + if oauth_id_token: + try: + async with ClientSession() as session: + async with session.get(OPENID_PROVIDER_URL.value) as resp: + if resp.status == 200: + openid_data = await resp.json() + logout_url = openid_data.get("end_session_endpoint") + if logout_url: + response.delete_cookie("oauth_id_token") + return RedirectResponse( + url=f"{logout_url}?id_token_hint={oauth_id_token}" + ) + else: + raise HTTPException( + status_code=resp.status, + detail="Failed to fetch OpenID configuration", + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + return {"status": True} diff --git 
a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/routers/chats.py similarity index 98% rename from backend/open_webui/apps/webui/routers/chats.py rename to backend/open_webui/routers/chats.py index db95337d5..5e0e75e24 100644 --- a/backend/open_webui/apps/webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -2,15 +2,15 @@ import json import logging from typing import Optional -from open_webui.apps.webui.models.chats import ( +from open_webui.models.chats import ( ChatForm, ChatImportForm, ChatResponse, Chats, ChatTitleIdResponse, ) -from open_webui.apps.webui.models.tags import TagModel, Tags -from open_webui.apps.webui.models.folders import Folders +from open_webui.models.tags import TagModel, Tags +from open_webui.models.folders import Folders from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES @@ -19,7 +19,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_permission log = logging.getLogger(__name__) @@ -607,7 +607,6 @@ async def add_tag_by_id_and_tag_name( detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"), ) - print(tags, tag_id) if tag_id not in tags: Chats.add_chat_tag_by_id_and_user_id_and_tag_name( id, user.id, form_data.name diff --git a/backend/open_webui/apps/webui/routers/configs.py b/backend/open_webui/routers/configs.py similarity index 97% rename from backend/open_webui/apps/webui/routers/configs.py rename to backend/open_webui/routers/configs.py index 7466e6fda..ef6c4d8c1 100644 --- a/backend/open_webui/apps/webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from typing import Optional -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.config import get_config, save_config from open_webui.config import BannerModel diff --git a/backend/open_webui/apps/webui/routers/evaluations.py b/backend/open_webui/routers/evaluations.py similarity index 96% rename from backend/open_webui/apps/webui/routers/evaluations.py rename to backend/open_webui/routers/evaluations.py index b9e3bff29..f0c4a6b06 100644 --- a/backend/open_webui/apps/webui/routers/evaluations.py +++ b/backend/open_webui/routers/evaluations.py @@ -2,8 +2,8 @@ from typing import Optional from fastapi import APIRouter, Depends, HTTPException, status, Request from pydantic import BaseModel -from open_webui.apps.webui.models.users import Users, UserModel -from open_webui.apps.webui.models.feedbacks import ( +from open_webui.models.users import Users, UserModel +from open_webui.models.feedbacks import ( FeedbackModel, FeedbackResponse, FeedbackForm, @@ -11,7 +11,7 @@ from open_webui.apps.webui.models.feedbacks import ( ) from open_webui.constants import ERROR_MESSAGES -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user router = APIRouter() diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/routers/files.py similarity index 89% rename from backend/open_webui/apps/webui/routers/files.py rename to backend/open_webui/routers/files.py index e7459a15f..fa36a03ea 100644 --- 
a/backend/open_webui/apps/webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -5,27 +5,28 @@ from pathlib import Path from typing import Optional from pydantic import BaseModel import mimetypes +from urllib.parse import quote from open_webui.storage.provider import Storage -from open_webui.apps.webui.models.files import ( +from open_webui.models.files import ( FileForm, FileModel, FileModelResponse, Files, ) -from open_webui.apps.retrieval.main import process_file, ProcessFileForm +from open_webui.routers.retrieval import process_file, ProcessFileForm from open_webui.config import UPLOAD_DIR from open_webui.env import SRC_LOG_LEVELS from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request from fastapi.responses import FileResponse, StreamingResponse -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -39,7 +40,9 @@ router = APIRouter() @router.post("/", response_model=FileModelResponse) -def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): +def upload_file( + request: Request, file: UploadFile = File(...), user=Depends(get_verified_user) +): log.info(f"file.content_type: {file.content_type}") try: unsanitized_filename = file.filename @@ -68,7 +71,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): ) try: - process_file(ProcessFileForm(file_id=id)) + process_file(request, ProcessFileForm(file_id=id)) file_item = Files.get_file_by_id(id=id) except Exception as e: log.exception(e) @@ -183,13 +186,15 @@ class ContentForm(BaseModel): @router.post("/{id}/data/content/update") async def update_file_data_content_by_id( - id: str, form_data: ContentForm, user=Depends(get_verified_user) + request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user) ): file = Files.get_file_by_id(id) if file and (file.user_id == user.id or user.role == "admin"): try: - process_file(ProcessFileForm(file_id=id, content=form_data.content)) + process_file( + request, ProcessFileForm(file_id=id, content=form_data.content) + ) file = Files.get_file_by_id(id=id) except Exception as e: log.exception(e) @@ -218,11 +223,15 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): # Check if the file already exists in the cache if file_path.is_file(): - print(f"file_path: {file_path}") + # Handle Unicode filenames + filename = file.meta.get("name", file.filename) + encoded_filename = quote(filename) # RFC5987 encoding headers = { - "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"' + "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}" } + return FileResponse(file_path, headers=headers) + else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -279,16 +288,20 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): if file and (file.user_id == user.id or user.role == "admin"): file_path = file.path + + # Handle Unicode filenames + filename = file.meta.get("name", file.filename) + encoded_filename = quote(filename) # RFC5987 encoding + headers = { + "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}" + } + if file_path: file_path = Storage.get_file(file_path) file_path = 
Path(file_path) # Check if the file already exists in the cache if file_path.is_file(): - print(f"file_path: {file_path}") - headers = { - "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"' - } return FileResponse(file_path, headers=headers) else: raise HTTPException( @@ -307,7 +320,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): return StreamingResponse( generator(), media_type="text/plain", - headers={"Content-Disposition": f"attachment; filename={file_name}"}, + headers=headers, ) else: raise HTTPException( diff --git a/backend/open_webui/apps/webui/routers/folders.py b/backend/open_webui/routers/folders.py similarity index 97% rename from backend/open_webui/apps/webui/routers/folders.py rename to backend/open_webui/routers/folders.py index 36075c357..ca2fbd213 100644 --- a/backend/open_webui/apps/webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -8,12 +8,12 @@ from pydantic import BaseModel import mimetypes -from open_webui.apps.webui.models.folders import ( +from open_webui.models.folders import ( FolderForm, FolderModel, Folders, ) -from open_webui.apps.webui.models.chats import Chats +from open_webui.models.chats import Chats from open_webui.config import UPLOAD_DIR from open_webui.env import SRC_LOG_LEVELS @@ -24,7 +24,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status from fastapi.responses import FileResponse, StreamingResponse -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/open_webui/apps/webui/routers/functions.py b/backend/open_webui/routers/functions.py similarity index 98% rename from backend/open_webui/apps/webui/routers/functions.py rename to backend/open_webui/routers/functions.py index aeaceecfb..7f3305f25 100644 --- a/backend/open_webui/apps/webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -2,17 +2,17 @@ import os from pathlib import Path from typing import Optional -from open_webui.apps.webui.models.functions import ( +from open_webui.models.functions import ( FunctionForm, FunctionModel, FunctionResponse, Functions, ) -from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports +from open_webui.utils.plugin import load_function_module_by_id, replace_imports from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user router = APIRouter() diff --git a/backend/open_webui/apps/webui/routers/groups.py b/backend/open_webui/routers/groups.py similarity index 96% rename from backend/open_webui/apps/webui/routers/groups.py rename to backend/open_webui/routers/groups.py index 59d7d0052..e8f8994a4 100644 --- a/backend/open_webui/apps/webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -2,7 +2,7 @@ import os from pathlib import Path from typing import Optional -from open_webui.apps.webui.models.groups import ( +from open_webui.models.groups import ( Groups, GroupForm, GroupUpdateForm, @@ -12,7 +12,7 @@ from open_webui.apps.webui.models.groups import ( from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, 
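The files-router change above replaces the quoted-string `filename="..."` form, which mangles non-ASCII names, with the RFC 5987/6266 `filename*=UTF-8''...` form built from `urllib.parse.quote`. A standalone sketch with an illustrative Unicode filename:

```python
# RFC 5987 Content-Disposition encoding, as adopted in the files router above.
from urllib.parse import quote

filename = "レポート 2024.pdf"   # illustrative Unicode filename
encoded = quote(filename)        # percent-encodes the UTF-8 bytes

headers = {"Content-Disposition": f"attachment; filename*=UTF-8''{encoded}"}
print(headers["Content-Disposition"])
# attachment; filename*=UTF-8''%E3%83%AC%E3%83%9D%E3%83%BC%E3%83%88%202024.pdf
```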
Depends, HTTPException, Request, status -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user router = APIRouter() diff --git a/backend/open_webui/apps/images/main.py b/backend/open_webui/routers/images.py similarity index 57% rename from backend/open_webui/apps/images/main.py rename to backend/open_webui/routers/images.py index 62c76425d..3f51fbdb4 100644 --- a/backend/open_webui/apps/images/main.py +++ b/backend/open_webui/routers/images.py @@ -9,38 +9,24 @@ from pathlib import Path from typing import Optional import requests -from open_webui.apps.images.utils.comfyui import ( + + +from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel + + +from open_webui.config import CACHE_DIR +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS + +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.images.comfyui import ( ComfyUIGenerateImageForm, ComfyUIWorkflow, comfyui_generate_image, ) -from open_webui.config import ( - AUTOMATIC1111_API_AUTH, - AUTOMATIC1111_BASE_URL, - AUTOMATIC1111_CFG_SCALE, - AUTOMATIC1111_SAMPLER, - AUTOMATIC1111_SCHEDULER, - CACHE_DIR, - COMFYUI_BASE_URL, - COMFYUI_WORKFLOW, - COMFYUI_WORKFLOW_NODES, - CORS_ALLOW_ORIGIN, - ENABLE_IMAGE_GENERATION, - IMAGE_GENERATION_ENGINE, - IMAGE_GENERATION_MODEL, - IMAGE_SIZE, - IMAGE_STEPS, - IMAGES_OPENAI_API_BASE_URL, - IMAGES_OPENAI_API_KEY, - AppConfig, -) -from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS -from fastapi import Depends, FastAPI, HTTPException, Request -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["IMAGES"]) @@ -48,63 +34,30 @@ log.setLevel(SRC_LOG_LEVELS["IMAGES"]) IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.config = AppConfig() - -app.state.config.ENGINE = IMAGE_GENERATION_ENGINE -app.state.config.ENABLED = ENABLE_IMAGE_GENERATION - -app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL -app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY - -app.state.config.MODEL = IMAGE_GENERATION_MODEL - -app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL -app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH -app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE -app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER -app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER -app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL -app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW -app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES - -app.state.config.IMAGE_SIZE = IMAGE_SIZE -app.state.config.IMAGE_STEPS = IMAGE_STEPS +router = APIRouter() -@app.get("/config") +@router.get("/config") async def get_config(request: Request, 
user=Depends(get_admin_user)): return { - "enabled": app.state.config.ENABLED, - "engine": app.state.config.ENGINE, + "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, + "engine": request.app.state.config.IMAGE_GENERATION_ENGINE, "openai": { - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, }, "automatic1111": { - "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, - "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE, - "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER, - "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER, + "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, + "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH, + "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE, + "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER, + "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER, }, "comfyui": { - "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, - "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, - "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, + "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL, + "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, + "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, } @@ -117,7 +70,7 @@ class OpenAIConfigForm(BaseModel): class Automatic1111ConfigForm(BaseModel): AUTOMATIC1111_BASE_URL: str AUTOMATIC1111_API_AUTH: str - AUTOMATIC1111_CFG_SCALE: Optional[str] + AUTOMATIC1111_CFG_SCALE: Optional[str | float | int] AUTOMATIC1111_SAMPLER: Optional[str] AUTOMATIC1111_SCHEDULER: Optional[str] @@ -136,133 +89,156 @@ class ConfigForm(BaseModel): comfyui: ComfyUIConfigForm -@app.post("/config/update") -async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)): - app.state.config.ENGINE = form_data.engine - app.state.config.ENABLED = form_data.enabled +@router.post("/config/update") +async def update_config( + request: Request, form_data: ConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine + request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled - app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL - app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY + request.app.state.config.IMAGES_OPENAI_API_BASE_URL = ( + form_data.openai.OPENAI_API_BASE_URL + ) + request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY - app.state.config.AUTOMATIC1111_BASE_URL = ( + request.app.state.config.AUTOMATIC1111_BASE_URL = ( form_data.automatic1111.AUTOMATIC1111_BASE_URL ) - app.state.config.AUTOMATIC1111_API_AUTH = ( + request.app.state.config.AUTOMATIC1111_API_AUTH = ( form_data.automatic1111.AUTOMATIC1111_API_AUTH ) - app.state.config.AUTOMATIC1111_CFG_SCALE = ( + request.app.state.config.AUTOMATIC1111_CFG_SCALE = ( float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE) if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE else None ) - app.state.config.AUTOMATIC1111_SAMPLER = ( + request.app.state.config.AUTOMATIC1111_SAMPLER = ( form_data.automatic1111.AUTOMATIC1111_SAMPLER if 
form_data.automatic1111.AUTOMATIC1111_SAMPLER else None ) - app.state.config.AUTOMATIC1111_SCHEDULER = ( + request.app.state.config.AUTOMATIC1111_SCHEDULER = ( form_data.automatic1111.AUTOMATIC1111_SCHEDULER if form_data.automatic1111.AUTOMATIC1111_SCHEDULER else None ) - app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/") - app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW - app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES + request.app.state.config.COMFYUI_BASE_URL = ( + form_data.comfyui.COMFYUI_BASE_URL.strip("/") + ) + request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW + request.app.state.config.COMFYUI_WORKFLOW_NODES = ( + form_data.comfyui.COMFYUI_WORKFLOW_NODES + ) return { - "enabled": app.state.config.ENABLED, - "engine": app.state.config.ENGINE, + "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, + "engine": request.app.state.config.IMAGE_GENERATION_ENGINE, "openai": { - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, }, "automatic1111": { - "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, - "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE, - "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER, - "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER, + "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, + "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH, + "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE, + "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER, + "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER, }, "comfyui": { - "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, - "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, - "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, + "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL, + "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, + "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, } -def get_automatic1111_api_auth(): - if app.state.config.AUTOMATIC1111_API_AUTH is None: +def get_automatic1111_api_auth(request: Request): + if request.app.state.config.AUTOMATIC1111_API_AUTH is None: return "" else: - auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8") + auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode( + "utf-8" + ) auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string) auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8") return f"Basic {auth1111_base64_encoded_string}" -@app.get("/config/url/verify") -async def verify_url(user=Depends(get_admin_user)): - if app.state.config.ENGINE == "automatic1111": +@router.get("/config/url/verify") +async def verify_url(request: Request, user=Depends(get_admin_user)): + if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111": try: r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - headers={"authorization": get_automatic1111_api_auth()}, + 
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + headers={"authorization": get_automatic1111_api_auth(request)}, ) r.raise_for_status() return True except Exception: - app.state.config.ENABLED = False + request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) - elif app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": try: - r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") + r = requests.get( + url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" + ) r.raise_for_status() return True except Exception: - app.state.config.ENABLED = False + request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) else: return True -def set_image_model(model: str): +def set_image_model(request: Request, model: str): log.info(f"Setting image model to {model}") - app.state.config.MODEL = model - if app.state.config.ENGINE in ["", "automatic1111"]: + request.app.state.config.IMAGE_GENERATION_MODEL = model + if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]: api_auth = get_automatic1111_api_auth() r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", headers={"authorization": api_auth}, ) options = r.json() if model != options["sd_model_checkpoint"]: options["sd_model_checkpoint"] = model r = requests.post( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options, headers={"authorization": api_auth}, ) - return app.state.config.MODEL + return request.app.state.config.IMAGE_GENERATION_MODEL -def get_image_model(): - if app.state.config.ENGINE == "openai": - return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2" - elif app.state.config.ENGINE == "comfyui": - return app.state.config.MODEL if app.state.config.MODEL else "" - elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "": +def get_image_model(request): + if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": + return ( + request.app.state.config.IMAGE_GENERATION_MODEL + if request.app.state.config.IMAGE_GENERATION_MODEL + else "dall-e-2" + ) + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": + return ( + request.app.state.config.IMAGE_GENERATION_MODEL + if request.app.state.config.IMAGE_GENERATION_MODEL + else "" + ) + elif ( + request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" + or request.app.state.config.IMAGE_GENERATION_ENGINE == "" + ): try: r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", headers={"authorization": get_automatic1111_api_auth()}, ) options = r.json() return options["sd_model_checkpoint"] except Exception as e: - app.state.config.ENABLED = False + request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -272,23 +248,25 @@ class ImageConfigForm(BaseModel): IMAGE_STEPS: int -@app.get("/image/config") -async def get_image_config(user=Depends(get_admin_user)): +@router.get("/image/config") +async def get_image_config(request: Request, user=Depends(get_admin_user)): return { - 
"MODEL": app.state.config.MODEL, - "IMAGE_SIZE": app.state.config.IMAGE_SIZE, - "IMAGE_STEPS": app.state.config.IMAGE_STEPS, + "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL, + "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE, + "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS, } -@app.post("/image/config/update") -async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)): +@router.post("/image/config/update") +async def update_image_config( + request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user) +): - set_image_model(form_data.MODEL) + set_image_model(request, form_data.MODEL) pattern = r"^\d+x\d+$" if re.match(pattern, form_data.IMAGE_SIZE): - app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE + request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE else: raise HTTPException( status_code=400, @@ -296,7 +274,7 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin ) if form_data.IMAGE_STEPS >= 0: - app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS + request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS else: raise HTTPException( status_code=400, @@ -304,29 +282,31 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin ) return { - "MODEL": app.state.config.MODEL, - "IMAGE_SIZE": app.state.config.IMAGE_SIZE, - "IMAGE_STEPS": app.state.config.IMAGE_STEPS, + "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL, + "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE, + "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS, } -@app.get("/models") -def get_models(user=Depends(get_verified_user)): +@router.get("/models") +def get_models(request: Request, user=Depends(get_verified_user)): try: - if app.state.config.ENGINE == "openai": + if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": return [ {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] - elif app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": # TODO - get models from comfyui - r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") + r = requests.get( + url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" + ) info = r.json() - workflow = json.loads(app.state.config.COMFYUI_WORKFLOW) + workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW) model_node_id = None - for node in app.state.config.COMFYUI_WORKFLOW_NODES: + for node in request.app.state.config.COMFYUI_WORKFLOW_NODES: if node["type"] == "model": if node["node_ids"]: model_node_id = node["node_ids"][0] @@ -362,10 +342,11 @@ def get_models(user=Depends(get_verified_user)): ) ) elif ( - app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" + or request.app.state.config.IMAGE_GENERATION_ENGINE == "" ): r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", headers={"authorization": get_automatic1111_api_auth()}, ) models = r.json() @@ -376,7 +357,7 @@ def get_models(user=Depends(get_verified_user)): ) ) except Exception as e: - app.state.config.ENABLED = False + request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -448,18 +429,21 @@ def save_url_image(url): return None -@app.post("/generations") +@router.post("/generations") async def 
image_generations( + request: Request, form_data: GenerateImageForm, user=Depends(get_verified_user), ): - width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x"))) + width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x"))) r = None try: - if app.state.config.ENGINE == "openai": + if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": headers = {} - headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" + headers["Authorization"] = ( + f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}" + ) headers["Content-Type"] = "application/json" if ENABLE_FORWARD_USER_INFO_HEADERS: @@ -470,14 +454,16 @@ async def image_generations( data = { "model": ( - app.state.config.MODEL - if app.state.config.MODEL != "" + request.app.state.config.IMAGE_GENERATION_MODEL + if request.app.state.config.IMAGE_GENERATION_MODEL != "" else "dall-e-2" ), "prompt": form_data.prompt, "n": form_data.n, "size": ( - form_data.size if form_data.size else app.state.config.IMAGE_SIZE + form_data.size + if form_data.size + else request.app.state.config.IMAGE_SIZE ), "response_format": "b64_json", } @@ -485,7 +471,7 @@ async def image_generations( # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations", + url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations", json=data, headers=headers, ) @@ -505,7 +491,7 @@ async def image_generations( return images - elif app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": data = { "prompt": form_data.prompt, "width": width, @@ -513,8 +499,8 @@ async def image_generations( "n": form_data.n, } - if app.state.config.IMAGE_STEPS is not None: - data["steps"] = app.state.config.IMAGE_STEPS + if request.app.state.config.IMAGE_STEPS is not None: + data["steps"] = request.app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt @@ -523,18 +509,18 @@ async def image_generations( **{ "workflow": ComfyUIWorkflow( **{ - "workflow": app.state.config.COMFYUI_WORKFLOW, - "nodes": app.state.config.COMFYUI_WORKFLOW_NODES, + "workflow": request.app.state.config.COMFYUI_WORKFLOW, + "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES, } ), **data, } ) res = await comfyui_generate_image( - app.state.config.MODEL, + request.app.state.config.IMAGE_GENERATION_MODEL, form_data, user.id, - app.state.config.COMFYUI_BASE_URL, + request.app.state.config.COMFYUI_BASE_URL, ) log.debug(f"res: {res}") @@ -551,7 +537,8 @@ async def image_generations( log.debug(f"images: {images}") return images elif ( - app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" + or request.app.state.config.IMAGE_GENERATION_ENGINE == "" ): if form_data.model: set_image_model(form_data.model) @@ -563,25 +550,25 @@ async def image_generations( "height": height, } - if app.state.config.IMAGE_STEPS is not None: - data["steps"] = app.state.config.IMAGE_STEPS + if request.app.state.config.IMAGE_STEPS is not None: + data["steps"] = request.app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt - if app.state.config.AUTOMATIC1111_CFG_SCALE: - data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE + if request.app.state.config.AUTOMATIC1111_CFG_SCALE: + data["cfg_scale"] = 
request.app.state.config.AUTOMATIC1111_CFG_SCALE - if app.state.config.AUTOMATIC1111_SAMPLER: - data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER + if request.app.state.config.AUTOMATIC1111_SAMPLER: + data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER - if app.state.config.AUTOMATIC1111_SCHEDULER: - data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER + if request.app.state.config.AUTOMATIC1111_SCHEDULER: + data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", json=data, headers={"authorization": get_automatic1111_api_auth()}, ) diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py similarity index 96% rename from backend/open_webui/apps/webui/routers/knowledge.py rename to backend/open_webui/routers/knowledge.py index 1b063cda2..7f9947d7a 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -4,19 +4,19 @@ from pydantic import BaseModel from fastapi import APIRouter, Depends, HTTPException, status, Request import logging -from open_webui.apps.webui.models.knowledge import ( +from open_webui.models.knowledge import ( Knowledges, KnowledgeForm, KnowledgeResponse, KnowledgeUserResponse, ) -from open_webui.apps.webui.models.files import Files, FileModel -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT -from open_webui.apps.retrieval.main import process_file, ProcessFileForm +from open_webui.models.files import Files, FileModel +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.routers.retrieval import process_file, ProcessFileForm from open_webui.constants import ERROR_MESSAGES -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission @@ -242,6 +242,7 @@ class KnowledgeFileIdForm(BaseModel): @router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse]) def add_file_to_knowledge_by_id( + request: Request, id: str, form_data: KnowledgeFileIdForm, user=Depends(get_verified_user), @@ -274,7 +275,9 @@ def add_file_to_knowledge_by_id( # Add content to the vector database try: - process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id)) + process_file( + request, ProcessFileForm(file_id=form_data.file_id, collection_name=id) + ) except Exception as e: log.debug(e) raise HTTPException( @@ -318,6 +321,7 @@ def add_file_to_knowledge_by_id( @router.post("/{id}/file/update", response_model=Optional[KnowledgeFilesResponse]) def update_file_from_knowledge_by_id( + request: Request, id: str, form_data: KnowledgeFileIdForm, user=Depends(get_verified_user), @@ -349,7 +353,9 @@ def update_file_from_knowledge_by_id( # Add content to the vector database try: - process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id)) + process_file( + request, ProcessFileForm(file_id=form_data.file_id, collection_name=id) + ) except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/backend/open_webui/apps/webui/routers/memories.py b/backend/open_webui/routers/memories.py similarity index 96% rename from 
backend/open_webui/apps/webui/routers/memories.py rename to backend/open_webui/routers/memories.py index ccf84a9d4..e72cf1445 100644 --- a/backend/open_webui/apps/webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -3,9 +3,9 @@ from pydantic import BaseModel import logging from typing import Optional -from open_webui.apps.webui.models.memories import Memories, MemoryModel -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT -from open_webui.utils.utils import get_verified_user +from open_webui.models.memories import Memories, MemoryModel +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.utils.auth import get_verified_user from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/apps/webui/routers/models.py b/backend/open_webui/routers/models.py similarity index 97% rename from backend/open_webui/apps/webui/routers/models.py rename to backend/open_webui/routers/models.py index 6a8085385..db981a913 100644 --- a/backend/open_webui/apps/webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -1,6 +1,6 @@ from typing import Optional -from open_webui.apps.webui.models.models import ( +from open_webui.models.models import ( ModelForm, ModelModel, ModelResponse, @@ -11,7 +11,7 @@ from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/routers/ollama.py similarity index 58% rename from backend/open_webui/apps/ollama/main.py rename to backend/open_webui/routers/ollama.py index 82a37a752..233e30ce5 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/routers/ollama.py @@ -1,3 +1,7 @@ +# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. +# The current implementation picks a backend uniformly at random (random.choice), not true round-robin. Consider incorporating algorithms like round-robin, weighted round-robin, +# least connections, or least response time for better resource utilization and performance optimization. 
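# --- Editor's sketch, illustrative and not part of this patch: one concrete shape the
# TODO above could take. A minimal least-connections selector over the configured base
# URLs; the class and method names below are hypothetical, and a real implementation
# would also need to handle per-URL health checks and runtime config changes.
import contextlib
from collections import defaultdict


class LeastConnectionsSelector:
    """Pick the backend URL currently serving the fewest in-flight requests."""

    def __init__(self, urls: list[str]):
        self.urls = urls
        self.in_flight: dict[str, int] = defaultdict(int)  # url -> active requests

    @contextlib.asynccontextmanager
    async def acquire(self):
        # min() over the counters is O(n), which is fine for the handful of
        # backend URLs typical of OLLAMA_BASE_URLS.
        url = min(self.urls, key=lambda u: self.in_flight[u])
        self.in_flight[url] += 1
        try:
            yield url  # caller forwards its request to this backend
        finally:
            self.in_flight[url] -= 1

# Hypothetical usage: wrap each proxied request in `async with selector.acquire() as url:`.
# Because asyncio runs the counter updates on one thread, no lock is needed.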
+ import asyncio import json import logging @@ -12,31 +16,23 @@ import aiohttp from aiocache import cached import requests -from open_webui.apps.webui.models.models import Models -from open_webui.config import ( - CORS_ALLOW_ORIGIN, - ENABLE_OLLAMA_API, - OLLAMA_BASE_URLS, - OLLAMA_API_CONFIGS, - UPLOAD_DIR, - AppConfig, -) -from open_webui.env import ( - AIOHTTP_CLIENT_TIMEOUT, - AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, - BYPASS_MODEL_ACCESS_CONTROL, -) - -from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ENV, SRC_LOG_LEVELS -from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile +from fastapi import ( + Depends, + FastAPI, + File, + HTTPException, + Request, + UploadFile, + APIRouter, +) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel, ConfigDict from starlette.background import BackgroundTask +from open_webui.models.models import Models from open_webui.utils.misc import ( calculate_sha256, ) @@ -45,131 +41,40 @@ from open_webui.utils.payload import ( apply_model_params_to_body_openai, apply_model_system_prompt_to_body, ) -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access + +from open_webui.config import ( + UPLOAD_DIR, +) +from open_webui.env import ( + ENV, + SRC_LOG_LEVELS, + AIOHTTP_CLIENT_TIMEOUT, + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, + BYPASS_MODEL_ACCESS_CONTROL, +) +from open_webui.constants import ERROR_MESSAGES + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.config = AppConfig() - -app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API -app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS -app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS +########################################## +# +# Utility functions +# +########################################## -# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. -# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, -# least connections, or least response time for better resource utilization and performance optimization. 
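# --- Editor's note, illustrative and not part of this patch: the header-building idiom
# that recurs throughout this refactor, e.g. in send_get_request above, is
#     {**({"Authorization": f"Bearer {key}"} if key else {})}
# The conditional evaluates to either a one-entry dict or an empty dict, and ** splices
# it into the enclosing literal, so no mutable headers variable or None check is needed.
# The helper below is hypothetical (the patch inlines the expression at each call site):
def build_headers(key: str | None) -> dict[str, str]:
    return {
        "Content-Type": "application/json",
        **({"Authorization": f"Bearer {key}"} if key else {}),
    }

# build_headers(None)  -> {"Content-Type": "application/json"}
# build_headers("abc") -> {"Content-Type": "application/json", "Authorization": "Bearer abc"}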
- - -@app.head("/") -@app.get("/") -async def get_status(): - return {"status": True} - - -class ConnectionVerificationForm(BaseModel): - url: str - key: Optional[str] = None - - -@app.post("/verify") -async def verify_connection( - form_data: ConnectionVerificationForm, user=Depends(get_admin_user) -): - url = form_data.url - key = form_data.key - - headers = {} - if key: - headers["Authorization"] = f"Bearer {key}" - - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) - async with aiohttp.ClientSession(timeout=timeout) as session: - try: - async with session.get(f"{url}/api/version", headers=headers) as r: - if r.status != 200: - # Extract response error details if available - error_detail = f"HTTP Error: {r.status}" - res = await r.json() - if "error" in res: - error_detail = f"External Error: {res['error']}" - raise Exception(error_detail) - - response_data = await r.json() - return response_data - - except aiohttp.ClientError as e: - # ClientError covers all aiohttp requests issues - log.exception(f"Client error: {str(e)}") - # Handle aiohttp-specific connection issues, timeout etc. - raise HTTPException( - status_code=500, detail="Open WebUI: Server Connection Error" - ) - except Exception as e: - log.exception(f"Unexpected error: {e}") - # Generic error handler in case parsing JSON or other steps fail - error_detail = f"Unexpected error: {str(e)}" - raise HTTPException(status_code=500, detail=error_detail) - - -@app.get("/config") -async def get_config(user=Depends(get_admin_user)): - return { - "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, - "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, - "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS, - } - - -class OllamaConfigForm(BaseModel): - ENABLE_OLLAMA_API: Optional[bool] = None - OLLAMA_BASE_URLS: list[str] - OLLAMA_API_CONFIGS: dict - - -@app.post("/config/update") -async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)): - app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API - app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS - - app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS - - # Remove any extra configs - config_urls = app.state.config.OLLAMA_API_CONFIGS.keys() - for url in list(app.state.config.OLLAMA_BASE_URLS): - if url not in config_urls: - app.state.config.OLLAMA_API_CONFIGS.pop(url, None) - - return { - "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, - "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, - "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS, - } - - -async def aiohttp_get(url, key=None): +async def send_get_request(url, key=None): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: - headers = {"Authorization": f"Bearer {key}"} if key else {} async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get(url, headers=headers) as response: + async with session.get( + url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + ) as response: return await response.json() except Exception as e: # Handle connection error here @@ -177,46 +82,44 @@ async def aiohttp_get(url, key=None): return None -async def cleanup_response( - response: Optional[aiohttp.ClientResponse], - session: Optional[aiohttp.ClientSession], +async def send_post_request( + url: str, + payload: Union[str, bytes], + stream: bool = True, + key: Optional[str] = None, + content_type: Optional[str] = None, ): - if 
response: - response.close() - if session: - await session.close() + async def cleanup_response( + response: Optional[aiohttp.ClientResponse], + session: Optional[aiohttp.ClientSession], + ): + if response: + response.close() + if session: + await session.close() - -async def post_streaming_url( - url: str, payload: Union[str, bytes], stream: bool = True, content_type=None -): r = None try: session = aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - r = await session.post( url, data=payload, - headers=headers, + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, ) r.raise_for_status() if stream: response_headers = dict(r.headers) + if content_type: response_headers["Content-Type"] = content_type + return StreamingResponse( r.content, status_code=r.status, @@ -231,61 +134,146 @@ async def post_streaming_url( return res except Exception as e: - error_detail = "Open WebUI: Server Connection Error" + detail = None + if r is not None: try: res = await r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res.get('error', 'Unknown error')}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) -def merge_models_lists(model_lists): - merged_models = {} +def get_api_key(url, configs): + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + return configs.get(base_url, {}).get("key", None) - for idx, model_list in enumerate(model_lists): - if model_list is not None: - for model in model_list: - id = model["model"] - if id not in merged_models: - model["urls"] = [idx] - merged_models[id] = model - else: - merged_models[id]["urls"].append(idx) - return list(merged_models.values()) +########################################## +# +# API routes +# +########################################## + +router = APIRouter() + + +@router.head("/") +@router.get("/") +async def get_status(): + return {"status": True} + + +class ConnectionVerificationForm(BaseModel): + url: str + key: Optional[str] = None + + +@router.post("/verify") +async def verify_connection( + form_data: ConnectionVerificationForm, user=Depends(get_admin_user) +): + url = form_data.url + key = form_data.key + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + ) as session: + try: + async with session.get( + f"{url}/api/version", + headers={**({"Authorization": f"Bearer {key}"} if key else {})}, + ) as r: + if r.status != 200: + detail = f"HTTP Error: {r.status}" + res = await r.json() + + if "error" in res: + detail = f"External Error: {res['error']}" + raise Exception(detail) + + data = await r.json() + return data + except aiohttp.ClientError as e: + log.exception(f"Client error: {str(e)}") + raise HTTPException( + status_code=500, detail="Open WebUI: Server Connection Error" + ) + except Exception as e: + log.exception(f"Unexpected error: {e}") + error_detail = f"Unexpected error: {str(e)}" + raise HTTPException(status_code=500, 
detail=error_detail) + + +@router.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): + return { + "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API, + "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS, + "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS, + } + + +class OllamaConfigForm(BaseModel): + ENABLE_OLLAMA_API: Optional[bool] = None + OLLAMA_BASE_URLS: list[str] + OLLAMA_API_CONFIGS: dict + + +@router.post("/config/update") +async def update_config( + request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API + + request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS + request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS + + # Remove any extra configs + config_urls = request.app.state.config.OLLAMA_API_CONFIGS.keys() + for url in list(request.app.state.config.OLLAMA_BASE_URLS): + if url not in config_urls: + request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None) + + return { + "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API, + "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS, + "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS, + } @cached(ttl=3) -async def get_all_models(): +async def get_all_models(request: Request): log.info("get_all_models()") - if app.state.config.ENABLE_OLLAMA_API: - tasks = [] - for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS): - if url not in app.state.config.OLLAMA_API_CONFIGS: - tasks.append(aiohttp_get(f"{url}/api/tags")) + if request.app.state.config.ENABLE_OLLAMA_API: + request_tasks = [] + + for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS): + if url not in request.app.state.config.OLLAMA_API_CONFIGS: + request_tasks.append(send_get_request(f"{url}/api/tags")) else: - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) enable = api_config.get("enable", True) key = api_config.get("key", None) if enable: - tasks.append(aiohttp_get(f"{url}/api/tags", key)) + request_tasks.append(send_get_request(f"{url}/api/tags", key)) else: - tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) + request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) - responses = await asyncio.gather(*tasks) + responses = await asyncio.gather(*request_tasks) for idx, response in enumerate(responses): if response: - url = app.state.config.OLLAMA_BASE_URLS[idx] - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + url = request.app.state.config.OLLAMA_BASE_URLS[idx] + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) prefix_id = api_config.get("prefix_id", None) model_ids = api_config.get("model_ids", []) @@ -302,6 +290,21 @@ async def get_all_models(): for model in response.get("models", []): model["model"] = f"{prefix_id}.{model['model']}" + def merge_models_lists(model_lists): + merged_models = {} + + for idx, model_list in enumerate(model_lists): + if model_list is not None: + for model in model_list: + id = model["model"] + if id not in merged_models: + model["urls"] = [idx] + merged_models[id] = model + else: + merged_models[id]["urls"].append(idx) + + return list(merged_models.values()) + models = { "models": merge_models_lists( map( @@ -314,81 +317,87 @@ async def get_all_models(): else: models = {"models": []} + request.app.state.OLLAMA_MODELS = { + 
model["model"]: model for model in models["models"] + } return models -@app.get("/api/tags") -@app.get("/api/tags/{url_idx}") +async def get_filtered_models(models, user): + # Filter models based on user access control + filtered_models = [] + for model in models.get("models", []): + model_info = Models.get_model_by_id(model["model"]) + if model_info: + if user.id == model_info.user_id or has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + return filtered_models + + +@router.get("/api/tags") +@router.get("/api/tags/{url_idx}") async def get_ollama_tags( - url_idx: Optional[int] = None, user=Depends(get_verified_user) + request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user) ): models = [] + if url_idx is None: - models = await get_all_models() + models = await get_all_models(request) else: - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {} - if key: - headers["Authorization"] = f"Bearer {key}" + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) r = None try: - r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers) + r = requests.request( + method="GET", + url=f"{url}/api/tags", + headers={**({"Authorization": f"Bearer {key}"} if key else {})}, + ) r.raise_for_status() models = r.json() except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status_code if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: - # Filter models based on user access control - filtered_models = [] - for model in models.get("models", []): - model_info = Models.get_model_by_id(model["model"]) - if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control - ): - filtered_models.append(model) - models["models"] = filtered_models + models["models"] = get_filtered_models(models, user) return models -@app.get("/api/version") -@app.get("/api/version/{url_idx}") -async def get_ollama_versions(url_idx: Optional[int] = None): - if app.state.config.ENABLE_OLLAMA_API: +@router.get("/api/version") +@router.get("/api/version/{url_idx}") +async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): + if request.app.state.config.ENABLE_OLLAMA_API: if url_idx is None: # returns lowest version - tasks = [ - aiohttp_get( + request_tasks = [ + send_get_request( f"{url}/api/version", - app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None), + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( + "key", None + ), ) - for url in app.state.config.OLLAMA_BASE_URLS + for url in request.app.state.config.OLLAMA_BASE_URLS ] - responses = await asyncio.gather(*tasks) + responses = await asyncio.gather(*request_tasks) responses = list(filter(lambda x: x is not None, responses)) if len(responses) > 0: @@ -406,7 +415,7 @@ async def 
get_ollama_versions(url_idx: Optional[int] = None): detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, ) else: - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] r = None try: @@ -416,39 +425,69 @@ async def get_ollama_versions(url_idx: Optional[int] = None): return r.json() except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status_code if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) else: return {"version": False} +@router.get("/api/ps") +async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)): + """ + List models that are currently loaded into Ollama memory, and which node they are loaded on. + """ + if request.app.state.config.ENABLE_OLLAMA_API: + request_tasks = [ + send_get_request( + f"{url}/api/ps", + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( + "key", None + ), + ) + for url in request.app.state.config.OLLAMA_BASE_URLS + ] + responses = await asyncio.gather(*request_tasks) + + return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses)) + else: + return {} + + class ModelNameForm(BaseModel): name: str -@app.post("/api/pull") -@app.post("/api/pull/{url_idx}") +@router.post("/api/pull") +@router.post("/api/pull/{url_idx}") async def pull_model( - form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) + request: Request, + form_data: ModelNameForm, + url_idx: int = 0, + user=Depends(get_admin_user), ): - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") # Admin should be able to pull models from any source payload = {**form_data.model_dump(exclude_none=True), "insecure": True} - return await post_streaming_url(f"{url}/api/pull", json.dumps(payload)) + return await send_post_request( + url=f"{url}/api/pull", + payload=json.dumps(payload), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + ) class PushModelForm(BaseModel): @@ -457,16 +496,17 @@ class PushModelForm(BaseModel): stream: Optional[bool] = None -@app.delete("/api/push") -@app.delete("/api/push/{url_idx}") +@router.delete("/api/push") +@router.delete("/api/push/{url_idx}") async def push_model( + request: Request, form_data: PushModelForm, url_idx: Optional[int] = None, user=Depends(get_admin_user), ): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS if form_data.name in models: url_idx = models[form_data.name]["urls"][0] @@ -476,11 +516,13 @@ async def push_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.debug(f"url: {url}") - return await post_streaming_url( - f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode() + return await send_post_request( + url=f"{url}/api/push", + payload=form_data.model_dump_json(exclude_none=True).encode(), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -491,17 +533,21 @@ class 
CreateModelForm(BaseModel): path: Optional[str] = None -@app.post("/api/create") -@app.post("/api/create/{url_idx}") +@router.post("/api/create") +@router.post("/api/create/{url_idx}") async def create_model( - form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) + request: Request, + form_data: CreateModelForm, + url_idx: int = 0, + user=Depends(get_admin_user), ): log.debug(f"form_data: {form_data}") - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - return await post_streaming_url( - f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode() + return await send_post_request( + url=f"{url}/api/create", + payload=form_data.model_dump_json(exclude_none=True).encode(), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -510,16 +556,17 @@ class CopyModelForm(BaseModel): destination: str -@app.post("/api/copy") -@app.post("/api/copy/{url_idx}") +@router.post("/api/copy") +@router.post("/api/copy/{url_idx}") async def copy_model( + request: Request, form_data: CopyModelForm, url_idx: Optional[int] = None, user=Depends(get_admin_user), ): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS if form_data.source in models: url_idx = models[form_data.source]["urls"][0] @@ -529,59 +576,52 @@ async def copy_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="POST", - url=f"{url}/api/copy", - headers=headers, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) try: + r = requests.request( + method="POST", + url=f"{url}/api/copy", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) r.raise_for_status() log.debug(f"r.text: {r.text}") - return True except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status_code if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) -@app.delete("/api/delete") -@app.delete("/api/delete/{url_idx}") +@router.delete("/api/delete") +@router.delete("/api/delete/{url_idx}") async def delete_model( + request: Request, form_data: ModelNameForm, url_idx: Optional[int] = None, user=Depends(get_admin_user), ): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS if form_data.name in models: url_idx = 
models[form_data.name]["urls"][0] @@ -591,52 +631,47 @@ async def delete_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="DELETE", - url=f"{url}/api/delete", - data=form_data.model_dump_json(exclude_none=True).encode(), - headers=headers, - ) try: + r = requests.request( + method="DELETE", + url=f"{url}/api/delete", + data=form_data.model_dump_json(exclude_none=True).encode(), + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + ) r.raise_for_status() log.debug(f"r.text: {r.text}") - return True except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status_code if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) -@app.post("/api/show") -async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)): - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} +@router.post("/api/show") +async def show_model_info( + request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) +): + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS if form_data.name not in models: raise HTTPException( @@ -645,53 +680,41 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us ) url_idx = random.choice(models[form_data.name]["urls"]) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="POST", - url=f"{url}/api/show", - headers=headers, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) try: + r = requests.request( + method="POST", + url=f"{url}/api/show", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) r.raise_for_status() return r.json() except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( 
status_code=r.status_code if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) -class GenerateEmbeddingsForm(BaseModel): - model: str - prompt: str - options: Optional[dict] = None - keep_alive: Optional[Union[int, str]] = None - - class GenerateEmbedForm(BaseModel): model: str input: list[str] | str @@ -700,105 +723,19 @@ class GenerateEmbedForm(BaseModel): keep_alive: Optional[Union[int, str]] = None -@app.post("/api/embed") -@app.post("/api/embed/{url_idx}") -async def generate_embeddings( +@router.post("/api/embed") +@router.post("/api/embed/{url_idx}") +async def embed( + request: Request, form_data: GenerateEmbedForm, url_idx: Optional[int] = None, user=Depends(get_verified_user), -): - return await generate_ollama_batch_embeddings(form_data, url_idx) - - -@app.post("/api/embeddings") -@app.post("/api/embeddings/{url_idx}") -async def generate_embeddings( - form_data: GenerateEmbeddingsForm, - url_idx: Optional[int] = None, - user=Depends(get_verified_user), -): - return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx) - - -async def generate_ollama_embeddings( - form_data: GenerateEmbeddingsForm, - url_idx: Optional[int] = None, -): - log.info(f"generate_ollama_embeddings {form_data}") - - if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} - - model = form_data.model - - if ":" not in model: - model = f"{model}:latest" - - if model in models: - url_idx = random.choice(models[model]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="POST", - url=f"{url}/api/embeddings", - headers=headers, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) - try: - r.raise_for_status() - - data = r.json() - - log.info(f"generate_ollama_embeddings {data}") - - if "embedding" in data: - return data - else: - raise Exception("Something went wrong :/") - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) - - -async def generate_ollama_batch_embeddings( - form_data: GenerateEmbedForm, - url_idx: Optional[int] = None, ): log.info(f"generate_ollama_batch_embeddings {form_data}") if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -813,48 +750,108 @@ async def generate_ollama_batch_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) - parsed_url = urlparse(url) - base_url = 
f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="POST", - url=f"{url}/api/embed", - headers=headers, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) try: + r = requests.request( + method="POST", + url=f"{url}/api/embed", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) r.raise_for_status() data = r.json() - - log.info(f"generate_ollama_batch_embeddings {data}") - - if "embeddings" in data: - return data - else: - raise Exception("Something went wrong :/") + return data except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" - raise Exception(error_detail) + raise HTTPException( + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + +class GenerateEmbeddingsForm(BaseModel): + model: str + prompt: str + options: Optional[dict] = None + keep_alive: Optional[Union[int, str]] = None + + +@router.post("/api/embeddings") +@router.post("/api/embeddings/{url_idx}") +async def embeddings( + request: Request, + form_data: GenerateEmbeddingsForm, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), +): + log.info(f"generate_ollama_embeddings {form_data}") + + if url_idx is None: + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS + + model = form_data.model + + if ":" not in model: + model = f"{model}:latest" + + if model in models: + url_idx = random.choice(models[model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + + try: + r = requests.request( + method="POST", + url=f"{url}/api/embeddings", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + r.raise_for_status() + + data = r.json() + return data + except Exception as e: + log.exception(e) + + detail = None + if r is not None: + try: + res = r.json() + if "error" in res: + detail = f"Ollama: {res['error']}" + except Exception: + detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) class GenerateCompletionForm(BaseModel): @@ -872,16 +869,17 @@ class GenerateCompletionForm(BaseModel): keep_alive: Optional[Union[int, str]] = None -@app.post("/api/generate") -@app.post("/api/generate/{url_idx}") +@router.post("/api/generate") +@router.post("/api/generate/{url_idx}") async def generate_completion( + request: Request, form_data: GenerateCompletionForm, url_idx: Optional[int] = None, user=Depends(get_verified_user), ): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await 
get_all_models(request) + models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -896,15 +894,17 @@ async def generate_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) if prefix_id: form_data.model = form_data.model.replace(f"{prefix_id}.", "") - log.info(f"url: {url}") - return await post_streaming_url( - f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode() + return await send_post_request( + url=f"{url}/api/generate", + payload=form_data.model_dump_json(exclude_none=True).encode(), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -924,31 +924,41 @@ class GenerateChatCompletionForm(BaseModel): keep_alive: Optional[Union[int, str]] = None -async def get_ollama_url(url_idx: Optional[int], model: str): +async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} - + models = request.app.state.OLLAMA_MODELS if model not in models: raise HTTPException( status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), ) - url_idx = random.choice(models[model]["urls"]) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url_idx = random.choice(models[model].get("urls", [])) + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] return url -@app.post("/api/chat") -@app.post("/api/chat/{url_idx}") +@router.post("/api/chat") +@router.post("/api/chat/{url_idx}") async def generate_chat_completion( - form_data: GenerateChatCompletionForm, + request: Request, + form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), bypass_filter: Optional[bool] = False, ): + if BYPASS_MODEL_ACCESS_CONTROL: + bypass_filter = True + + try: + form_data = GenerateChatCompletionForm(**form_data) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=400, + detail=str(e), + ) + payload = {**form_data.model_dump(exclude_none=True)} - log.debug(f"generate_chat_completion() - 1.payload = {payload}") if "metadata" in payload: del payload["metadata"] @@ -992,22 +1002,18 @@ async def generate_chat_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = await get_ollama_url(url_idx, payload["model"]) - log.info(f"url: {url}") - log.debug(f"generate_chat_completion() - 2.payload = {payload}") + url = await get_ollama_url(request, payload["model"], url_idx) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") - return await post_streaming_url( - f"{url}/api/chat", - json.dumps(payload), + return await send_post_request( + url=f"{url}/api/chat", + payload=json.dumps(payload), stream=form_data.stream, + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), content_type="application/x-ndjson", ) @@ -1032,9 +1038,89 @@ class OpenAIChatCompletionForm(BaseModel): model_config = ConfigDict(extra="allow") 
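# --- Editor's sketch, illustrative and not part of this patch: the prefix_id
# convention these endpoints rely on. When a connection's api_config carries a
# prefix_id, get_all_models advertises its models as f"{prefix_id}.{model}", and each
# proxying endpoint strips the prefix again before forwarding, via
# payload["model"].replace(f"{prefix_id}.", ""). The helpers below are hypothetical;
# the patch inlines these expressions, and stripping mirrors its replace-based logic.
def add_prefix(model: str, prefix_id: str | None) -> str:
    return f"{prefix_id}.{model}" if prefix_id else model


def strip_prefix(model: str, prefix_id: str | None) -> str:
    return model.replace(f"{prefix_id}.", "") if prefix_id else model

# e.g. with prefix_id = "team-a":
#   add_prefix("llama3:latest", "team-a")          -> "team-a.llama3:latest"
#   strip_prefix("team-a.llama3:latest", "team-a") -> "llama3:latest"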
-@app.post("/v1/chat/completions") -@app.post("/v1/chat/completions/{url_idx}") +class OpenAICompletionForm(BaseModel): + model: str + prompt: str + + model_config = ConfigDict(extra="allow") + + +@router.post("/v1/completions") +@router.post("/v1/completions/{url_idx}") +async def generate_openai_completion( + request: Request, + form_data: dict, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), +): + try: + form_data = OpenAICompletionForm(**form_data) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=400, + detail=str(e), + ) + + payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])} + if "metadata" in payload: + del payload["metadata"] + + model_id = form_data.model + if ":" not in model_id: + model_id = f"{model_id}:latest" + + model_info = Models.get_model_by_id(model_id) + if model_info: + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + params = model_info.params.model_dump() + + if params: + payload = apply_model_params_to_body_openai(params, payload) + + # Check if user has access to the model + if user.role == "user": + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + else: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) + + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" + + url = await get_ollama_url(request, payload["model"], url_idx) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + + prefix_id = api_config.get("prefix_id", None) + + if prefix_id: + payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + + return await send_post_request( + url=f"{url}/v1/completions", + payload=json.dumps(payload), + stream=payload.get("stream", False), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + ) + + +@router.post("/v1/chat/completions") +@router.post("/v1/chat/completions/{url_idx}") async def generate_openai_chat_completion( + request: Request, form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), @@ -1068,7 +1154,7 @@ async def generate_openai_chat_completion( payload = apply_model_system_prompt_to_body(params, payload, user) # Check if user has access to the model - if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: + if user.role == "user": if not ( user.id == model_info.user_id or has_access( @@ -1089,31 +1175,32 @@ async def generate_openai_chat_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = await get_ollama_url(url_idx, payload["model"]) - log.info(f"url: {url}") + url = await get_ollama_url(request, payload["model"], url_idx) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") - return await post_streaming_url( - f"{url}/v1/chat/completions", - json.dumps(payload), + return await send_post_request( + url=f"{url}/v1/chat/completions", + payload=json.dumps(payload), stream=payload.get("stream", False), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) -@app.get("/v1/models") -@app.get("/v1/models/{url_idx}") +@router.get("/v1/models") +@router.get("/v1/models/{url_idx}") 
async def get_openai_models( + request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user), ): models = [] if url_idx is None: - model_list = await get_all_models() + model_list = await get_all_models(request) models = [ { "id": model["model"], @@ -1125,7 +1212,7 @@ async def get_openai_models( ] else: - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/tags") r.raise_for_status() @@ -1249,9 +1336,10 @@ async def download_file_stream( # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" -@app.post("/models/download") -@app.post("/models/download/{url_idx}") +@router.post("/models/download") +@router.post("/models/download/{url_idx}") async def download_model( + request: Request, form_data: UrlForm, url_idx: Optional[int] = None, user=Depends(get_admin_user), @@ -1266,7 +1354,7 @@ async def download_model( if url_idx is None: url_idx = 0 - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] file_name = parse_huggingface_url(form_data.url) @@ -1280,16 +1368,17 @@ async def download_model( return None -@app.post("/models/upload") -@app.post("/models/upload/{url_idx}") +@router.post("/models/upload") +@router.post("/models/upload/{url_idx}") def upload_model( + request: Request, file: UploadFile = File(...), url_idx: Optional[int] = None, user=Depends(get_admin_user), ): if url_idx is None: url_idx = 0 - ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] + ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] file_path = f"{UPLOAD_DIR}/{file.filename}" diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/routers/openai.py similarity index 52% rename from backend/open_webui/apps/openai/main.py rename to backend/open_webui/routers/openai.py index 9193c2be6..f7f78be85 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/routers/openai.py @@ -10,15 +10,15 @@ from aiocache import cached import requests -from open_webui.apps.webui.models.models import Models +from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse, StreamingResponse +from pydantic import BaseModel +from starlette.background import BackgroundTask + +from open_webui.models.models import Models from open_webui.config import ( CACHE_DIR, - CORS_ALLOW_ORIGIN, - ENABLE_OPENAI_API, - OPENAI_API_BASE_URLS, - OPENAI_API_KEYS, - OPENAI_API_CONFIGS, - AppConfig, ) from open_webui.env import ( AIOHTTP_CLIENT_TIMEOUT, @@ -29,18 +29,14 @@ from open_webui.env import ( from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENV, SRC_LOG_LEVELS -from fastapi import Depends, FastAPI, HTTPException, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse, StreamingResponse -from pydantic import BaseModel -from starlette.background import BackgroundTask + from open_webui.utils.payload import ( apply_model_params_to_body_openai, apply_model_system_prompt_to_body, ) -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access @@ -48,36 +44,69 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OPENAI"]) -app = FastAPI( - 
docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) +########################################## +# +# Utility functions +# +########################################## -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.config = AppConfig() - -app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API -app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS -app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS -app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS +async def send_get_request(url, key=None): + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + try: + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: + async with session.get( + url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + ) as response: + return await response.json() + except Exception as e: + # Handle connection error here + log.error(f"Connection error: {e}") + return None -@app.get("/config") -async def get_config(user=Depends(get_admin_user)): +async def cleanup_response( + response: Optional[aiohttp.ClientResponse], + session: Optional[aiohttp.ClientSession], +): + if response: + response.close() + if session: + await session.close() + + +def openai_o1_handler(payload): + """ + Handle O1 specific parameters + """ + if "max_tokens" in payload: + # Remove "max_tokens" from the payload + payload["max_completion_tokens"] = payload["max_tokens"] + del payload["max_tokens"] + + # Fix: O1 does not support the "system" parameter, Modify "system" to "user" + if payload["messages"][0]["role"] == "system": + payload["messages"][0]["role"] = "user" + + return payload + + +########################################## +# +# API routes +# +########################################## + +router = APIRouter() + + +@router.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): return { - "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, - "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, - "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, - "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, + "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS, } @@ -88,50 +117,56 @@ class OpenAIConfigForm(BaseModel): OPENAI_API_CONFIGS: dict -@app.post("/config/update") -async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)): - app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API - - app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS - app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS +@router.post("/config/update") +async def update_config( + request: Request, form_data: OpenAIConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API + request.app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS + request.app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS # Check if API KEYS length is same than API URLS length - if len(app.state.config.OPENAI_API_KEYS) != len( - app.state.config.OPENAI_API_BASE_URLS + if len(request.app.state.config.OPENAI_API_KEYS) != len( + 
request.app.state.config.OPENAI_API_BASE_URLS ): - if len(app.state.config.OPENAI_API_KEYS) > len( - app.state.config.OPENAI_API_BASE_URLS + if len(request.app.state.config.OPENAI_API_KEYS) > len( + request.app.state.config.OPENAI_API_BASE_URLS ): - app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[ - : len(app.state.config.OPENAI_API_BASE_URLS) - ] + request.app.state.config.OPENAI_API_KEYS = ( + request.app.state.config.OPENAI_API_KEYS[ + : len(request.app.state.config.OPENAI_API_BASE_URLS) + ] + ) else: - app.state.config.OPENAI_API_KEYS += [""] * ( - len(app.state.config.OPENAI_API_BASE_URLS) - - len(app.state.config.OPENAI_API_KEYS) + request.app.state.config.OPENAI_API_KEYS += [""] * ( + len(request.app.state.config.OPENAI_API_BASE_URLS) + - len(request.app.state.config.OPENAI_API_KEYS) ) - app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS + request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS # Remove any extra configs - config_urls = app.state.config.OPENAI_API_CONFIGS.keys() - for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS): + config_urls = request.app.state.config.OPENAI_API_CONFIGS.keys() + for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS): if url not in config_urls: - app.state.config.OPENAI_API_CONFIGS.pop(url, None) + request.app.state.config.OPENAI_API_CONFIGS.pop(url, None) return { - "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, - "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, - "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, - "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, + "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS, } -@app.post("/audio/speech") +@router.post("/audio/speech") async def speech(request: Request, user=Depends(get_verified_user)): idx = None try: - idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") + idx = request.app.state.config.OPENAI_API_BASE_URLS.index( + "https://api.openai.com/v1" + ) + body = await request.body() name = hashlib.sha256(body).hexdigest() @@ -144,23 +179,35 @@ async def speech(request: Request, user=Depends(get_verified_user)): if file_path.is_file(): return FileResponse(file_path) - headers = {} - headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}" - headers["Content-Type"] = "application/json" - if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]: - headers["HTTP-Referer"] = "https://openwebui.com/" - headers["X-Title"] = "Open WebUI" - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + r = None try: r = requests.post( - url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech", + url=f"{url}/audio/speech", data=body, - headers=headers, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}", + **( + { + "HTTP-Referer": "https://openwebui.com/", + "X-Title": "Open WebUI", + } + if "openrouter.ai" in url + else {} + ), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + 
"X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, stream=True, ) @@ -179,115 +226,62 @@ async def speech(request: Request, user=Depends(get_verified_user)): except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"External: {res['error']}" + detail = f"External: {res['error']}" except Exception: - error_detail = f"External: {e}" + detail = f"External: {e}" raise HTTPException( - status_code=r.status_code if r else 500, detail=error_detail + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", ) except ValueError: raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) -async def aiohttp_get(url, key=None): - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) - try: - headers = {"Authorization": f"Bearer {key}"} if key else {} - async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get(url, headers=headers) as response: - return await response.json() - except Exception as e: - # Handle connection error here - log.error(f"Connection error: {e}") - return None - - -async def cleanup_response( - response: Optional[aiohttp.ClientResponse], - session: Optional[aiohttp.ClientSession], -): - if response: - response.close() - if session: - await session.close() - - -def merge_models_lists(model_lists): - log.debug(f"merge_models_lists {model_lists}") - merged_list = [] - - for idx, models in enumerate(model_lists): - if models is not None and "error" not in models: - merged_list.extend( - [ - { - **model, - "name": model.get("name", model["id"]), - "owned_by": "openai", - "openai": model, - "urlIdx": idx, - } - for model in models - if "api.openai.com" - not in app.state.config.OPENAI_API_BASE_URLS[idx] - or not any( - name in model["id"] - for name in [ - "babbage", - "dall-e", - "davinci", - "embedding", - "tts", - "whisper", - ] - ) - ] - ) - - return merged_list - - -async def get_all_models_responses() -> list: - if not app.state.config.ENABLE_OPENAI_API: +async def get_all_models_responses(request: Request) -> list: + if not request.app.state.config.ENABLE_OPENAI_API: return [] # Check if API KEYS length is same than API URLS length - num_urls = len(app.state.config.OPENAI_API_BASE_URLS) - num_keys = len(app.state.config.OPENAI_API_KEYS) + num_urls = len(request.app.state.config.OPENAI_API_BASE_URLS) + num_keys = len(request.app.state.config.OPENAI_API_KEYS) if num_keys != num_urls: # if there are more keys than urls, remove the extra keys if num_keys > num_urls: - new_keys = app.state.config.OPENAI_API_KEYS[:num_urls] - app.state.config.OPENAI_API_KEYS = new_keys + new_keys = request.app.state.config.OPENAI_API_KEYS[:num_urls] + request.app.state.config.OPENAI_API_KEYS = new_keys # if there are more urls than keys, add empty keys else: - app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys) + request.app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys) - tasks = [] - for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS): - if url not in app.state.config.OPENAI_API_CONFIGS: - tasks.append( - aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) + request_tasks = [] + for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS): + if url not in 
request.app.state.config.OPENAI_API_CONFIGS: + request_tasks.append( + send_get_request( + f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx] + ) ) else: - api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {}) enable = api_config.get("enable", True) model_ids = api_config.get("model_ids", []) if enable: if len(model_ids) == 0: - tasks.append( - aiohttp_get( - f"{url}/models", app.state.config.OPENAI_API_KEYS[idx] + request_tasks.append( + send_get_request( + f"{url}/models", + request.app.state.config.OPENAI_API_KEYS[idx], ) ) else: @@ -305,16 +299,18 @@ async def get_all_models_responses() -> list: ], } - tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list))) + request_tasks.append( + asyncio.ensure_future(asyncio.sleep(0, model_list)) + ) else: - tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) + request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) - responses = await asyncio.gather(*tasks) + responses = await asyncio.gather(*request_tasks) for idx, response in enumerate(responses): if response: - url = app.state.config.OPENAI_API_BASE_URLS[idx] - api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {}) + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {}) prefix_id = api_config.get("prefix_id", None) @@ -325,18 +321,30 @@ async def get_all_models_responses() -> list: model["id"] = f"{prefix_id}.{model['id']}" log.debug(f"get_all_models:responses() {responses}") - return responses +async def get_filtered_models(models, user): + # Filter models based on user access control + filtered_models = [] + for model in models.get("data", []): + model_info = Models.get_model_by_id(model["id"]) + if model_info: + if user.id == model_info.user_id or has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + return filtered_models + + @cached(ttl=3) -async def get_all_models() -> dict[str, list]: +async def get_all_models(request: Request) -> dict[str, list]: log.info("get_all_models()") - if not app.state.config.ENABLE_OPENAI_API: + if not request.app.state.config.ENABLE_OPENAI_API: return {"data": []} - responses = await get_all_models_responses() + responses = await get_all_models_responses(request) def extract_data(response): if response and "data" in response: @@ -345,41 +353,86 @@ async def get_all_models() -> dict[str, list]: return response return None + def merge_models_lists(model_lists): + log.debug(f"merge_models_lists {model_lists}") + merged_list = [] + + for idx, models in enumerate(model_lists): + if models is not None and "error" not in models: + merged_list.extend( + [ + { + **model, + "name": model.get("name", model["id"]), + "owned_by": "openai", + "openai": model, + "urlIdx": idx, + } + for model in models + if "api.openai.com" + not in request.app.state.config.OPENAI_API_BASE_URLS[idx] + or not any( + name in model["id"] + for name in [ + "babbage", + "dall-e", + "davinci", + "embedding", + "tts", + "whisper", + ] + ) + ] + ) + + return merged_list + models = {"data": merge_models_lists(map(extract_data, responses))} log.debug(f"models: {models}") + request.app.state.OPENAI_MODELS = {model["id"]: model for model in models["data"]} return models -@app.get("/models") -@app.get("/models/{url_idx}") -async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)): +@router.get("/models") 
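A detail that is easy to miss above: connections that are disabled or pinned to a static `model_ids` list never issue an HTTP request. Instead, `asyncio.sleep(0, value)` serves as a pre-resolved awaitable, so `asyncio.gather` still yields exactly one result per configured URL, in order. A toy demonstration:

```python
# Toy demonstration of mixing real coroutines with asyncio.sleep(0, value)
# placeholders so gather() returns one result per connection, in order.
import asyncio


async def fetch_models(url: str) -> dict:
    await asyncio.sleep(0.01)  # stands in for the real HTTP round trip
    return {"data": [{"id": f"model-from-{url}"}]}


async def main() -> None:
    static_list = {"data": [{"id": "pinned-model"}]}
    tasks = [
        fetch_models("http://a.example"),                      # real request
        asyncio.ensure_future(asyncio.sleep(0, static_list)),  # preset model_ids
        asyncio.ensure_future(asyncio.sleep(0, None)),         # disabled connection
    ]
    print(await asyncio.gather(*tasks))


asyncio.run(main())
```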
+@router.get("/models/{url_idx}") +async def get_models( + request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user) +): models = { "data": [], } if url_idx is None: - models = await get_all_models() + models = await get_all_models(request) else: - url = app.state.config.OPENAI_API_BASE_URLS[url_idx] - key = app.state.config.OPENAI_API_KEYS[url_idx] - - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role + url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx] + key = request.app.state.config.OPENAI_API_KEYS[url_idx] r = None - - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) - async with aiohttp.ClientSession(timeout=timeout) as session: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST + ) + ) as session: try: - async with session.get(f"{url}/models", headers=headers) as r: + async with session.get( + f"{url}/models", + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, + ) as r: if r.status != 200: # Extract response error details if available error_detail = f"HTTP Error: {r.status}" @@ -413,27 +466,16 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us except aiohttp.ClientError as e: # ClientError covers all aiohttp requests issues log.exception(f"Client error: {str(e)}") - # Handle aiohttp-specific connection issues, timeout etc. 
                raise HTTPException(
                    status_code=500, detail="Open WebUI: Server Connection Error"
                )
            except Exception as e:
                log.exception(f"Unexpected error: {e}")
-                # Generic error handler in case parsing JSON or other steps fail
                error_detail = f"Unexpected error: {str(e)}"
                raise HTTPException(status_code=500, detail=error_detail)

    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
-        # Filter models based on user access control
-        filtered_models = []
-        for model in models.get("data", []):
-            model_info = Models.get_model_by_id(model["id"])
-            if model_info:
-                if user.id == model_info.user_id or has_access(
-                    user.id, type="read", access_control=model_info.access_control
-                ):
-                    filtered_models.append(model)
-        models["data"] = filtered_models
+        models["data"] = await get_filtered_models(models, user)

    return models


@@ -443,21 +485,24 @@ class ConnectionVerificationForm(BaseModel):
    key: str


-@app.post("/verify")
+@router.post("/verify")
async def verify_connection(
    form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
):
    url = form_data.url
    key = form_data.key

-    headers = {}
-    headers["Authorization"] = f"Bearer {key}"
-    headers["Content-Type"] = "application/json"
-
-    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
-    async with aiohttp.ClientSession(timeout=timeout) as session:
+    async with aiohttp.ClientSession(
+        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
+    ) as session:
        try:
-            async with session.get(f"{url}/models", headers=headers) as r:
+            async with session.get(
+                f"{url}/models",
+                headers={
+                    "Authorization": f"Bearer {key}",
+                    "Content-Type": "application/json",
+                },
+            ) as r:
                if r.status != 200:
                    # Extract response error details if available
                    error_detail = f"HTTP Error: {r.status}"
@@ -472,26 +517,24 @@ async def verify_connection(
        except aiohttp.ClientError as e:
            # ClientError covers all aiohttp requests issues
            log.exception(f"Client error: {str(e)}")
-            # Handle aiohttp-specific connection issues, timeout etc.
raise HTTPException( status_code=500, detail="Open WebUI: Server Connection Error" ) except Exception as e: log.exception(f"Unexpected error: {e}") - # Generic error handler in case parsing JSON or other steps fail error_detail = f"Unexpected error: {str(e)}" raise HTTPException(status_code=500, detail=error_detail) -@app.post("/chat/completions") +@router.post("/chat/completions") async def generate_chat_completion( + request: Request, form_data: dict, user=Depends(get_verified_user), bypass_filter: Optional[bool] = False, ): idx = 0 payload = {**form_data} - if "metadata" in payload: del payload["metadata"] @@ -526,15 +569,7 @@ async def generate_chat_completion( detail="Model not found", ) - # Attemp to get urlIdx from the model - models = await get_all_models() - - # Find the model from the list - model = next( - (model for model in models["data"] if model["id"] == payload.get("model")), - None, - ) - + model = request.app.state.OPENAI_MODELS.get(model_id) if model: idx = model["urlIdx"] else: @@ -544,11 +579,11 @@ async def generate_chat_completion( ) # Get the API config for the model - api_config = app.state.config.OPENAI_API_CONFIGS.get( - app.state.config.OPENAI_API_BASE_URLS[idx], {} + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + request.app.state.config.OPENAI_API_BASE_URLS[idx], {} ) - prefix_id = api_config.get("prefix_id", None) + prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") @@ -561,43 +596,26 @@ async def generate_chat_completion( "role": user.role, } - url = app.state.config.OPENAI_API_BASE_URLS[idx] - key = app.state.config.OPENAI_API_KEYS[idx] + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens" is_o1 = payload["model"].lower().startswith("o1-") - # Change max_completion_tokens to max_tokens (Backward compatible) - if "api.openai.com" not in url and not is_o1: - if "max_completion_tokens" in payload: - # Remove "max_completion_tokens" from the payload - payload["max_tokens"] = payload["max_completion_tokens"] - del payload["max_completion_tokens"] - else: - if is_o1 and "max_tokens" in payload: + if is_o1: + payload = openai_o1_handler(payload) + elif "api.openai.com" not in url: + # Remove "max_tokens" from the payload for backward compatibility + if "max_tokens" in payload: payload["max_completion_tokens"] = payload["max_tokens"] del payload["max_tokens"] - if "max_tokens" in payload and "max_completion_tokens" in payload: - del payload["max_tokens"] - # Fix: O1 does not support the "system" parameter, Modify "system" to "user" - if is_o1 and payload["messages"][0]["role"] == "system": - payload["messages"][0]["role"] = "user" + # TODO: check if below is needed + # if "max_tokens" in payload and "max_completion_tokens" in payload: + # del payload["max_tokens"] # Convert the modified body back to JSON payload = json.dumps(payload) - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]: - headers["HTTP-Referer"] = "https://openwebui.com/" - headers["X-Title"] = "Open WebUI" - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role - r 
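`generate_chat_completion` no longer refetches and scans the full model list on every request: `get_all_models` leaves a `{model_id: model}` registry on app state, so resolving the upstream connection becomes one dict lookup plus a `prefix_id` strip. A sketch with illustrative data (the real per-connection config is keyed by URL rather than by index):

```python
# Sketch of the id -> connection resolution above: a registry keyed by model
# id replaces the old linear scan; prefix_id namespacing is stripped before
# the upstream call. All data here is illustrative.
OPENAI_MODELS = {
    "azure.gpt-4o": {"id": "azure.gpt-4o", "urlIdx": 1},
    "gpt-4o": {"id": "gpt-4o", "urlIdx": 0},
}
API_CONFIGS = [{}, {"prefix_id": "azure"}]  # per-connection config, by urlIdx

model_id = "azure.gpt-4o"
model = OPENAI_MODELS.get(model_id)
assert model is not None, "Model not found"

idx = model["urlIdx"]
prefix_id = API_CONFIGS[idx].get("prefix_id")
upstream_id = model_id.replace(f"{prefix_id}.", "") if prefix_id else model_id
print(idx, upstream_id)  # 1 gpt-4o
```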
= None session = None streaming = False @@ -607,11 +625,33 @@ async def generate_chat_completion( session = aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) + r = await session.request( method="POST", url=f"{url}/chat/completions", data=payload, - headers=headers, + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "HTTP-Referer": "https://openwebui.com/", + "X-Title": "Open WebUI", + } + if "openrouter.ai" in url + else {} + ), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, ) # Check if response is SSE @@ -636,14 +676,18 @@ async def generate_chat_completion( return response except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if isinstance(response, dict): if "error" in response: - error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}" + detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}" elif isinstance(response, str): - error_detail = response + detail = response - raise HTTPException(status_code=r.status if r else 500, detail=error_detail) + raise HTTPException( + status_code=r.status if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) finally: if not streaming and session: if r: @@ -651,25 +695,17 @@ async def generate_chat_completion( await session.close() -@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_verified_user)): - idx = 0 + """ + Deprecated: proxy all requests to OpenAI API + """ body = await request.body() - url = app.state.config.OPENAI_API_BASE_URLS[idx] - key = app.state.config.OPENAI_API_KEYS[idx] - - target_url = f"{url}/{path}" - - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role + idx = 0 + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] r = None session = None @@ -679,11 +715,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): session = aiohttp.ClientSession(trust_env=True) r = await session.request( method=request.method, - url=target_url, + url=f"{url}/{path}", data=body, - headers=headers, + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, ) - r.raise_for_status() # Check if response is SSE @@ -700,18 +748,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): else: response_data = await r.json() return response_data + except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = await r.json() print(res) if 
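The `# Check if response is SSE` branches above hinge on the response `Content-Type`: `text/event-stream` means the body is relayed chunk by chunk, anything else is read once as JSON. A hedged sketch of that dispatch; the exact upstream wiring (background cleanup task, header filtering) may differ:

```python
# Hedged sketch of the SSE-vs-JSON dispatch behind the "Check if response
# is SSE" branches; the upstream implementation details may differ.
import aiohttp
from fastapi.responses import StreamingResponse


async def relay(session: aiohttp.ClientSession, r: aiohttp.ClientResponse):
    if "text/event-stream" in r.headers.get("Content-Type", ""):
        # Relay the event stream as it arrives; the caller must keep the
        # session open until the stream is drained (hence cleanup_response).
        return StreamingResponse(
            r.content,  # aiohttp's StreamReader is an async iterator of bytes
            status_code=r.status,
            headers=dict(r.headers),
        )
    # Plain response: parse the JSON now and release the connection.
    data = await r.json()
    r.close()
    await session.close()
    return data
```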
"error" in res: - error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" + detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" except Exception: - error_detail = f"External: {e}" - raise HTTPException(status_code=r.status if r else 500, detail=error_detail) + detail = f"External: {e}" + raise HTTPException( + status_code=r.status if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) finally: if not streaming and session: if r: diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py new file mode 100644 index 000000000..258c10ee6 --- /dev/null +++ b/backend/open_webui/routers/pipelines.py @@ -0,0 +1,496 @@ +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + Request, + UploadFile, + status, + APIRouter, +) +import os +import logging +import shutil +import requests +from pydantic import BaseModel +from starlette.responses import FileResponse +from typing import Optional + +from open_webui.env import SRC_LOG_LEVELS +from open_webui.config import CACHE_DIR +from open_webui.constants import ERROR_MESSAGES + + +from open_webui.routers.openai import get_all_models_responses + +from open_webui.utils.auth import get_admin_user + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +################################## +# +# Pipeline Middleware +# +################################## + + +def get_sorted_filters(model_id, models): + filters = [ + model + for model in models.values() + if "pipeline" in model + and "type" in model["pipeline"] + and model["pipeline"]["type"] == "filter" + and ( + model["pipeline"]["pipelines"] == ["*"] + or any( + model_id == target_model_id + for target_model_id in model["pipeline"]["pipelines"] + ) + ) + ] + sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + return sorted_filters + + +def process_pipeline_inlet_filter(request, payload, user, models): + user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} + model_id = payload["model"] + + sorted_filters = get_sorted_filters(model_id, models) + model = models[model_id] + + if "pipeline" in model: + sorted_filters.append(model) + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + if key == "": + continue + + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/inlet", + headers=headers, + json={ + "user": user, + "body": payload, + }, + ) + + r.raise_for_status() + payload = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + res = r.json() + if "detail" in res: + raise Exception(r.status_code, res["detail"]) + + return payload + + +def process_pipeline_outlet_filter(request, payload, user, models): + user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} + model_id = payload["model"] + + sorted_filters = get_sorted_filters(model_id, models) + model = models[model_id] + + if "pipeline" in model: + sorted_filters = [model] + sorted_filters + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + if key != "": + r = 
requests.post( + f"{url}/{filter['id']}/filter/outlet", + headers={"Authorization": f"Bearer {key}"}, + json={ + "user": { + "id": user.id, + "name": user.name, + "email": user.email, + "role": user.role, + }, + "body": data, + }, + ) + + r.raise_for_status() + data = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + try: + res = r.json() + if "detail" in res: + return Exception(r.status_code, res) + except Exception: + pass + + else: + pass + + return payload + + +################################## +# +# Pipelines Endpoints +# +################################## + +router = APIRouter() + + +@router.get("/list") +async def get_pipelines_list(request: Request, user=Depends(get_admin_user)): + responses = await get_all_models_responses(request) + log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}") + + urlIdxs = [ + idx + for idx, response in enumerate(responses) + if response is not None and "pipelines" in response + ] + + return { + "data": [ + { + "url": request.app.state.config.OPENAI_API_BASE_URLS[urlIdx], + "idx": urlIdx, + } + for urlIdx in urlIdxs + ] + } + + +@router.post("/upload") +async def upload_pipeline( + request: Request, + urlIdx: int = Form(...), + file: UploadFile = File(...), + user=Depends(get_admin_user), +): + print("upload_pipeline", urlIdx, file.filename) + # Check if the uploaded file is a python file + if not (file.filename and file.filename.endswith(".py")): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Only Python (.py) files are allowed.", + ) + + upload_folder = f"{CACHE_DIR}/pipelines" + os.makedirs(upload_folder, exist_ok=True) + file_path = os.path.join(upload_folder, file.filename) + + r = None + try: + # Save the uploaded file + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + with open(file_path, "rb") as f: + files = {"file": f} + r = requests.post( + f"{url}/pipelines/upload", + headers={"Authorization": f"Bearer {key}"}, + files=files, + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + status_code = status.HTTP_404_NOT_FOUND + if r is not None: + status_code = r.status_code + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=status_code, + detail=detail if detail else "Pipeline not found", + ) + finally: + # Ensure the file is deleted after the upload is completed or on failure + if os.path.exists(file_path): + os.remove(file_path) + + +class AddPipelineForm(BaseModel): + url: str + urlIdx: int + + +@router.post("/add") +async def add_pipeline( + request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user) +): + r = None + try: + urlIdx = form_data.urlIdx + + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + r = requests.post( + f"{url}/pipelines/add", + headers={"Authorization": f"Bearer {key}"}, + json={"url": form_data.url}, + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + if r is not None: + try: + res = r.json() + if "detail" in res: + detail 
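The outlet pass is meant to mirror the inlet pass, with the response body threaded through each filter in turn. As written above it mixes `payload` and `data` and re-reads attributes off the already-flattened `user` dict, so here is a minimal normalized sketch of the loop, assuming the inlet behavior reflects the intended shape:

```python
# Minimal normalized sketch of the outlet filter chain: the same payload
# dict is threaded through every hop (assumed intent; the endpoint URL
# shape matches the inlet pass above).
import requests


def run_outlet_filters(
    payload: dict, user: dict, filters: list[dict], urls: list[str], keys: list[str]
) -> dict:
    for f in filters:
        idx = f["urlIdx"]
        key = keys[idx]
        if not key:
            continue  # unconfigured connection: skip, as the inlet pass does
        r = requests.post(
            f"{urls[idx]}/{f['id']}/filter/outlet",
            headers={"Authorization": f"Bearer {key}"},
            json={"user": user, "body": payload},
        )
        r.raise_for_status()
        payload = r.json()  # each filter's output feeds the next filter
    return payload
```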
= res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else "Pipeline not found", + ) + + +class DeletePipelineForm(BaseModel): + id: str + urlIdx: int + + +@router.delete("/delete") +async def delete_pipeline( + request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user) +): + r = None + try: + urlIdx = form_data.urlIdx + + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + r = requests.delete( + f"{url}/pipelines/delete", + headers={"Authorization": f"Bearer {key}"}, + json={"id": form_data.id}, + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else "Pipeline not found", + ) + + +@router.get("/") +async def get_pipelines( + request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user) +): + r = None + try: + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + r = requests.get(f"{url}/pipelines", headers={"Authorization": f"Bearer {key}"}) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else "Pipeline not found", + ) + + +@router.get("/{pipeline_id}/valves") +async def get_pipeline_valves( + request: Request, + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), +): + r = None + try: + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + r = requests.get( + f"{url}/{pipeline_id}/valves", headers={"Authorization": f"Bearer {key}"} + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else "Pipeline not found", + ) + + +@router.get("/{pipeline_id}/valves/spec") +async def get_pipeline_valves_spec( + request: Request, + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), +): + r = None + try: + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + r = requests.get( + f"{url}/{pipeline_id}/valves/spec", + headers={"Authorization": f"Bearer {key}"}, + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + 
pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else "Pipeline not found", + ) + + +@router.post("/{pipeline_id}/valves/update") +async def update_pipeline_valves( + request: Request, + urlIdx: Optional[int], + pipeline_id: str, + form_data: dict, + user=Depends(get_admin_user), +): + r = None + try: + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + r = requests.post( + f"{url}/{pipeline_id}/valves/update", + headers={"Authorization": f"Bearer {key}"}, + json={**form_data}, + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else "Pipeline not found", + ) diff --git a/backend/open_webui/apps/webui/routers/prompts.py b/backend/open_webui/routers/prompts.py similarity index 97% rename from backend/open_webui/apps/webui/routers/prompts.py rename to backend/open_webui/routers/prompts.py index 7cacde606..4f1c48482 100644 --- a/backend/open_webui/apps/webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -1,6 +1,6 @@ from typing import Optional -from open_webui.apps.webui.models.prompts import ( +from open_webui.models.prompts import ( PromptForm, PromptUserResponse, PromptModel, @@ -8,7 +8,7 @@ from open_webui.apps.webui.models.prompts import ( ) from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, status, Request -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission router = APIRouter() diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/routers/retrieval.py similarity index 50% rename from backend/open_webui/apps/retrieval/main.py rename to backend/open_webui/routers/retrieval.py index 341f4f500..e577f70f1 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/routers/retrieval.py @@ -1,5 +1,3 @@ -# TODO: Merge this with the webui_app and make it a single app - import json import logging import mimetypes @@ -11,38 +9,55 @@ from datetime import datetime from pathlib import Path from typing import Iterator, Optional, Sequence, Union -from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + UploadFile, + Request, + status, + APIRouter, +) from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import tiktoken +from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter +from langchain_core.documents import Document + +from open_webui.models.files import Files +from open_webui.models.knowledge import Knowledges from open_webui.storage.provider import Storage -from open_webui.apps.webui.models.knowledge import Knowledges -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT + + +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT # Document loaders -from open_webui.apps.retrieval.loaders.main import Loader -from 
open_webui.apps.retrieval.loaders.youtube import YoutubeLoader +from open_webui.retrieval.loaders.main import Loader +from open_webui.retrieval.loaders.youtube import YoutubeLoader # Web search engines -from open_webui.apps.retrieval.web.main import SearchResult -from open_webui.apps.retrieval.web.utils import get_web_loader -from open_webui.apps.retrieval.web.brave import search_brave -from open_webui.apps.retrieval.web.mojeek import search_mojeek -from open_webui.apps.retrieval.web.duckduckgo import search_duckduckgo -from open_webui.apps.retrieval.web.google_pse import search_google_pse -from open_webui.apps.retrieval.web.jina_search import search_jina -from open_webui.apps.retrieval.web.searchapi import search_searchapi -from open_webui.apps.retrieval.web.searxng import search_searxng -from open_webui.apps.retrieval.web.serper import search_serper -from open_webui.apps.retrieval.web.serply import search_serply -from open_webui.apps.retrieval.web.serpstack import search_serpstack -from open_webui.apps.retrieval.web.tavily import search_tavily -from open_webui.apps.retrieval.web.bing import search_bing +from open_webui.retrieval.web.main import SearchResult +from open_webui.retrieval.web.utils import get_web_loader +from open_webui.retrieval.web.brave import search_brave +from open_webui.retrieval.web.kagi import search_kagi +from open_webui.retrieval.web.mojeek import search_mojeek +from open_webui.retrieval.web.duckduckgo import search_duckduckgo +from open_webui.retrieval.web.google_pse import search_google_pse +from open_webui.retrieval.web.jina_search import search_jina +from open_webui.retrieval.web.searchapi import search_searchapi +from open_webui.retrieval.web.searxng import search_searxng +from open_webui.retrieval.web.serper import search_serper +from open_webui.retrieval.web.serply import search_serply +from open_webui.retrieval.web.serpstack import search_serpstack +from open_webui.retrieval.web.tavily import search_tavily +from open_webui.retrieval.web.bing import search_bing -from open_webui.apps.retrieval.utils import ( +from open_webui.retrieval.utils import ( get_embedding_function, get_model_path, query_collection, @@ -50,245 +65,100 @@ from open_webui.apps.retrieval.utils import ( query_doc, query_doc_with_hybrid_search, ) +from open_webui.utils.misc import ( + calculate_sha256_string, +) +from open_webui.utils.auth import get_admin_user, get_verified_user + -from open_webui.apps.webui.models.files import Files from open_webui.config import ( - BRAVE_SEARCH_API_KEY, - MOJEEK_SEARCH_API_KEY, - TIKTOKEN_ENCODING_NAME, - RAG_TEXT_SPLITTER, - CHUNK_OVERLAP, - CHUNK_SIZE, - CONTENT_EXTRACTION_ENGINE, - CORS_ALLOW_ORIGIN, - ENABLE_RAG_HYBRID_SEARCH, - ENABLE_RAG_LOCAL_WEB_FETCH, - ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - ENABLE_RAG_WEB_SEARCH, ENV, - GOOGLE_PSE_API_KEY, - GOOGLE_PSE_ENGINE_ID, - PDF_EXTRACT_IMAGES, - RAG_EMBEDDING_ENGINE, - RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, - RAG_EMBEDDING_BATCH_SIZE, - RAG_FILE_MAX_COUNT, - RAG_FILE_MAX_SIZE, - RAG_OPENAI_API_BASE_URL, - RAG_OPENAI_API_KEY, - RAG_OLLAMA_BASE_URL, - RAG_OLLAMA_API_KEY, - RAG_RELEVANCE_THRESHOLD, - RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, - DEFAULT_RAG_TEMPLATE, - RAG_TEMPLATE, - RAG_TOP_K, - RAG_WEB_SEARCH_CONCURRENT_REQUESTS, - RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - RAG_WEB_SEARCH_ENGINE, - RAG_WEB_SEARCH_RESULT_COUNT, - JINA_API_KEY, - SEARCHAPI_API_KEY, - SEARCHAPI_ENGINE, - SEARXNG_QUERY_URL, - 
SERPER_API_KEY, - SERPLY_API_KEY, - SERPSTACK_API_KEY, - SERPSTACK_HTTPS, - TAVILY_API_KEY, - BING_SEARCH_V7_ENDPOINT, - BING_SEARCH_V7_SUBSCRIPTION_KEY, - TIKA_SERVER_URL, UPLOAD_DIR, - YOUTUBE_LOADER_LANGUAGE, - YOUTUBE_LOADER_PROXY_URL, DEFAULT_LOCALE, - AppConfig, ) -from open_webui.constants import ERROR_MESSAGES from open_webui.env import ( SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER, ) -from open_webui.utils.misc import ( - calculate_sha256, - calculate_sha256_string, - extract_folders_after_data_docs, - sanitize_filename, -) -from open_webui.utils.utils import get_admin_user, get_verified_user - -from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter -from langchain_core.documents import Document - +from open_webui.constants import ERROR_MESSAGES log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) - -app.state.config = AppConfig() - -app.state.config.TOP_K = RAG_TOP_K -app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD -app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE -app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT - -app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH -app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( - ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION -) - -app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE -app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL - -app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER -app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME - -app.state.config.CHUNK_SIZE = CHUNK_SIZE -app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP - -app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE -app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL -app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE -app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL -app.state.config.RAG_TEMPLATE = RAG_TEMPLATE - -app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL -app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY - -app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL -app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY - -app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES - -app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE -app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL -app.state.YOUTUBE_LOADER_TRANSLATION = None +########################################## +# +# Utility functions +# +########################################## -app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH -app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE -app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST - -app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL -app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY -app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID -app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY -app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY -app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY -app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS -app.state.config.SERPER_API_KEY = SERPER_API_KEY -app.state.config.SERPLY_API_KEY = SERPLY_API_KEY -app.state.config.TAVILY_API_KEY = TAVILY_API_KEY -app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY -app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE 
-app.state.config.JINA_API_KEY = JINA_API_KEY -app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT -app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY - -app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT -app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS - - -def update_embedding_model( +def get_ef( + engine: str, embedding_model: str, auto_update: bool = False, ): - if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": + ef = None + if embedding_model and engine == "": from sentence_transformers import SentenceTransformer try: - app.state.sentence_transformer_ef = SentenceTransformer( + ef = SentenceTransformer( get_model_path(embedding_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, ) except Exception as e: log.debug(f"Error loading SentenceTransformer: {e}") - app.state.sentence_transformer_ef = None - else: - app.state.sentence_transformer_ef = None + + return ef -def update_reranking_model( +def get_rf( reranking_model: str, auto_update: bool = False, ): + rf = None if reranking_model: if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]): try: - from open_webui.apps.retrieval.models.colbert import ColBERT + from open_webui.retrieval.models.colbert import ColBERT - app.state.sentence_transformer_rf = ColBERT( + rf = ColBERT( get_model_path(reranking_model, auto_update), env="docker" if DOCKER else None, ) + except Exception as e: log.error(f"ColBERT: {e}") - app.state.sentence_transformer_rf = None - app.state.config.ENABLE_RAG_HYBRID_SEARCH = False + raise Exception(ERROR_MESSAGES.DEFAULT(e)) else: import sentence_transformers try: - app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder( + rf = sentence_transformers.CrossEncoder( get_model_path(reranking_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, ) except: log.error("CrossEncoder error") - app.state.sentence_transformer_rf = None - app.state.config.ENABLE_RAG_HYBRID_SEARCH = False - else: - app.state.sentence_transformer_rf = None + raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error")) + return rf -update_embedding_model( - app.state.config.RAG_EMBEDDING_MODEL, - RAG_EMBEDDING_MODEL_AUTO_UPDATE, -) - -update_reranking_model( - app.state.config.RAG_RERANKING_MODEL, - RAG_RERANKING_MODEL_AUTO_UPDATE, -) +########################################## +# +# API routes +# +########################################## -app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, - ( - app.state.config.OPENAI_API_BASE_URL - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_BASE_URL - ), - ( - app.state.config.OPENAI_API_KEY - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_API_KEY - ), - app.state.config.RAG_EMBEDDING_BATCH_SIZE, -) - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +router = APIRouter() class CollectionNameForm(BaseModel): @@ -303,43 +173,43 @@ class SearchForm(CollectionNameForm): query: str -@app.get("/") -async def get_status(): +@router.get("/") +async def get_status(request: Request): return { "status": True, - "chunk_size": app.state.config.CHUNK_SIZE, - "chunk_overlap": 
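`get_ef` returns a local `SentenceTransformer` only when no remote embedding engine is configured; for `openai` or `ollama` it returns `None` and embedding happens against the remote endpoint instead. A usage sketch (the model name is illustrative, and the weights download on first use):

```python
# Usage sketch of the get_ef contract: a local SentenceTransformer is loaded
# only when the engine string is empty. The model name is illustrative.
from sentence_transformers import SentenceTransformer


def get_ef(engine: str, embedding_model: str):
    if embedding_model and engine == "":
        return SentenceTransformer(embedding_model, device="cpu")
    return None  # "openai" / "ollama" engines embed remotely


ef = get_ef("", "sentence-transformers/all-MiniLM-L6-v2")
if ef is not None:
    vecs = ef.encode(["hello world"])
    print(vecs.shape)  # (1, 384) for this particular model
```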
app.state.config.CHUNK_OVERLAP, - "template": app.state.config.RAG_TEMPLATE, - "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, - "reranking_model": app.state.config.RAG_RERANKING_MODEL, - "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "chunk_size": request.app.state.config.CHUNK_SIZE, + "chunk_overlap": request.app.state.config.CHUNK_OVERLAP, + "template": request.app.state.config.RAG_TEMPLATE, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "reranking_model": request.app.state.config.RAG_RERANKING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, } -@app.get("/embedding") -async def get_embedding_config(user=Depends(get_admin_user)): +@router.get("/embedding") +async def get_embedding_config(request: Request, user=Depends(get_admin_user)): return { "status": True, - "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, - "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { - "url": app.state.config.OPENAI_API_BASE_URL, - "key": app.state.config.OPENAI_API_KEY, + "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, + "key": request.app.state.config.RAG_OPENAI_API_KEY, }, "ollama_config": { - "url": app.state.config.OLLAMA_BASE_URL, - "key": app.state.config.OLLAMA_API_KEY, + "url": request.app.state.config.RAG_OLLAMA_BASE_URL, + "key": request.app.state.config.RAG_OLLAMA_API_KEY, }, } -@app.get("/reranking") -async def get_reraanking_config(user=Depends(get_admin_user)): +@router.get("/reranking") +async def get_reraanking_config(request: Request, user=Depends(get_admin_user)): return { "status": True, - "reranking_model": app.state.config.RAG_RERANKING_MODEL, + "reranking_model": request.app.state.config.RAG_RERANKING_MODEL, } @@ -361,59 +231,72 @@ class EmbeddingModelUpdateForm(BaseModel): embedding_batch_size: Optional[int] = 1 -@app.post("/embedding/update") +@router.post("/embedding/update") async def update_embedding_config( - form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) + request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): log.info( - f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" + f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) try: - app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine - app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model + request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine + request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model - if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: + if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: if form_data.openai_config is not None: - app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url - app.state.config.OPENAI_API_KEY = form_data.openai_config.key + request.app.state.config.RAG_OPENAI_API_BASE_URL = ( + form_data.openai_config.url + ) + request.app.state.config.RAG_OPENAI_API_KEY = ( + form_data.openai_config.key + ) 
if form_data.ollama_config is not None: - app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url - app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key + request.app.state.config.RAG_OLLAMA_BASE_URL = ( + form_data.ollama_config.url + ) + request.app.state.config.RAG_OLLAMA_API_KEY = ( + form_data.ollama_config.key + ) - app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( + form_data.embedding_batch_size + ) - update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL) + request.app.state.ef = get_ef( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + ) - app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, + request.app.state.EMBEDDING_FUNCTION = get_embedding_function( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.ef, ( - app.state.config.OPENAI_API_BASE_URL - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_BASE_URL + request.app.state.config.RAG_OPENAI_API_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.RAG_OLLAMA_BASE_URL ), ( - app.state.config.OPENAI_API_KEY - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_API_KEY + request.app.state.config.RAG_OPENAI_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.RAG_OLLAMA_API_KEY ), - app.state.config.RAG_EMBEDDING_BATCH_SIZE, + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) return { "status": True, - "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, - "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { - "url": app.state.config.OPENAI_API_BASE_URL, - "key": app.state.config.OPENAI_API_KEY, + "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, + "key": request.app.state.config.RAG_OPENAI_API_KEY, }, "ollama_config": { - "url": app.state.config.OLLAMA_BASE_URL, - "key": app.state.config.OLLAMA_API_KEY, + "url": request.app.state.config.RAG_OLLAMA_BASE_URL, + "key": request.app.state.config.RAG_OLLAMA_API_KEY, }, } except Exception as e: @@ -428,21 +311,28 @@ class RerankingModelUpdateForm(BaseModel): reranking_model: str -@app.post("/reranking/update") +@router.post("/reranking/update") async def update_reranking_config( - form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) + request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) ): log.info( - f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}" + f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}" ) try: - app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model + request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model - update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True) + try: + request.app.state.rf = get_rf( + request.app.state.config.RAG_RERANKING_MODEL, + True, + ) + except 
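Which URL/key pair reaches `get_embedding_function` is decided by the engine string alone, as the two ternaries above show. Distilled into a runnable toy:

```python
# Distilled view of the url/key selection feeding get_embedding_function:
# the engine string alone picks between the OpenAI and Ollama pair.
def embedding_endpoint(engine: str, cfg: dict) -> tuple[str, str]:
    if engine == "openai":
        return cfg["RAG_OPENAI_API_BASE_URL"], cfg["RAG_OPENAI_API_KEY"]
    return cfg["RAG_OLLAMA_BASE_URL"], cfg["RAG_OLLAMA_API_KEY"]


cfg = {
    "RAG_OPENAI_API_BASE_URL": "https://api.openai.com/v1",
    "RAG_OPENAI_API_KEY": "sk-...",
    "RAG_OLLAMA_BASE_URL": "http://localhost:11434",
    "RAG_OLLAMA_API_KEY": "",
}
print(embedding_endpoint("openai", cfg)[0])  # https://api.openai.com/v1
print(embedding_endpoint("ollama", cfg)[0])  # http://localhost:11434
```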
Exception as e: + log.error(f"Error loading reranking model: {e}") + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False return { "status": True, - "reranking_model": app.state.config.RAG_RERANKING_MODEL, + "reranking_model": request.app.state.config.RAG_RERANKING_MODEL, } except Exception as e: log.exception(f"Problem updating reranking model: {e}") @@ -452,51 +342,52 @@ async def update_reranking_config( ) -@app.get("/config") -async def get_rag_config(user=Depends(get_admin_user)): +@router.get("/config") +async def get_rag_config(request: Request, user=Depends(get_admin_user)): return { "status": True, - "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, + "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, "content_extraction": { - "engine": app.state.config.CONTENT_EXTRACTION_ENGINE, - "tika_server_url": app.state.config.TIKA_SERVER_URL, + "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, + "tika_server_url": request.app.state.config.TIKA_SERVER_URL, }, "chunk": { - "text_splitter": app.state.config.TEXT_SPLITTER, - "chunk_size": app.state.config.CHUNK_SIZE, - "chunk_overlap": app.state.config.CHUNK_OVERLAP, + "text_splitter": request.app.state.config.TEXT_SPLITTER, + "chunk_size": request.app.state.config.CHUNK_SIZE, + "chunk_overlap": request.app.state.config.CHUNK_OVERLAP, }, "file": { - "max_size": app.state.config.FILE_MAX_SIZE, - "max_count": app.state.config.FILE_MAX_COUNT, + "max_size": request.app.state.config.FILE_MAX_SIZE, + "max_count": request.app.state.config.FILE_MAX_COUNT, }, "youtube": { - "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, - "translation": app.state.YOUTUBE_LOADER_TRANSLATION, - "proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL, + "language": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION, + "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, }, "web": { - "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "search": { - "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH, - "engine": app.state.config.RAG_WEB_SEARCH_ENGINE, - "searxng_query_url": app.state.config.SEARXNG_QUERY_URL, - "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY, - "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID, - "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY, - "mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY, - "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, - "serpstack_https": app.state.config.SERPSTACK_HTTPS, - "serper_api_key": app.state.config.SERPER_API_KEY, - "serply_api_key": app.state.config.SERPLY_API_KEY, - "tavily_api_key": app.state.config.TAVILY_API_KEY, - "searchapi_api_key": app.state.config.SEARCHAPI_API_KEY, - "seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE, - "jina_api_key": app.state.config.JINA_API_KEY, - "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT, - "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH, + "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE, + "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL, + "google_pse_api_key": 
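The reranker update above degrades gracefully: if `get_rf` raises, `ENABLE_RAG_HYBRID_SEARCH` is switched off rather than failing the whole config update. Downstream, that flag is what picks between hybrid and plain vector retrieval; a sketch of the effect, with the two query functions stubbed (their real implementations live in `open_webui.retrieval.utils`):

```python
# Sketch of how the hybrid-search fallback plays out at query time; both
# query functions are stubs standing in for open_webui.retrieval.utils.
def query_doc(collection: str, query: str) -> list[str]:
    return [f"vector hit for {query!r} in {collection}"]  # stub


def query_doc_with_hybrid_search(collection: str, query: str, reranker) -> list[str]:
    return [f"hybrid hit for {query!r} in {collection}"]  # stub


def retrieve(state: dict, collection: str, query: str) -> list[str]:
    # Hybrid search needs both the flag and a successfully loaded reranker.
    if state["ENABLE_RAG_HYBRID_SEARCH"] and state.get("rf") is not None:
        return query_doc_with_hybrid_search(collection, query, state["rf"])
    return query_doc(collection, query)


# After a failed reranker load, update_reranking_config flips the flag off:
state = {"ENABLE_RAG_HYBRID_SEARCH": False, "rf": None}
print(retrieve(state, "docs", "what is RAG?"))  # plain vector-search path
```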
request.app.state.config.GOOGLE_PSE_API_KEY, + "google_pse_engine_id": request.app.state.config.GOOGLE_PSE_ENGINE_ID, + "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY, + "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY, + "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY, + "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY, + "serpstack_https": request.app.state.config.SERPSTACK_HTTPS, + "serper_api_key": request.app.state.config.SERPER_API_KEY, + "serply_api_key": request.app.state.config.SERPLY_API_KEY, + "tavily_api_key": request.app.state.config.TAVILY_API_KEY, + "searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY, + "seaarchapi_engine": request.app.state.config.SEARCHAPI_ENGINE, + "jina_api_key": request.app.state.config.JINA_API_KEY, + "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT, + "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, }, } @@ -531,6 +422,7 @@ class WebSearchConfig(BaseModel): google_pse_api_key: Optional[str] = None google_pse_engine_id: Optional[str] = None brave_search_api_key: Optional[str] = None + kagi_search_api_key: Optional[str] = None mojeek_search_api_key: Optional[str] = None serpstack_api_key: Optional[str] = None serpstack_https: Optional[bool] = None @@ -560,137 +452,159 @@ class ConfigUpdateForm(BaseModel): web: Optional[WebConfig] = None -@app.post("/config/update") -async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - app.state.config.PDF_EXTRACT_IMAGES = ( +@router.post("/config/update") +async def update_rag_config( + request: Request, form_data: ConfigUpdateForm, user=Depends(get_admin_user) +): + request.app.state.config.PDF_EXTRACT_IMAGES = ( form_data.pdf_extract_images if form_data.pdf_extract_images is not None - else app.state.config.PDF_EXTRACT_IMAGES + else request.app.state.config.PDF_EXTRACT_IMAGES ) if form_data.file is not None: - app.state.config.FILE_MAX_SIZE = form_data.file.max_size - app.state.config.FILE_MAX_COUNT = form_data.file.max_count + request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size + request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count if form_data.content_extraction is not None: log.info(f"Updating text settings: {form_data.content_extraction}") - app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine - app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url + request.app.state.config.CONTENT_EXTRACTION_ENGINE = ( + form_data.content_extraction.engine + ) + request.app.state.config.TIKA_SERVER_URL = ( + form_data.content_extraction.tika_server_url + ) if form_data.chunk is not None: - app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter - app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size - app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap + request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter + request.app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size + request.app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap if form_data.youtube is not None: - app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language - app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url - 
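`WebSearchConfig` is all-optional, so a partial payload validates cleanly and unset engines simply come through as `None`. A standalone illustration over a subset of its fields:

```python
# Standalone illustration of the all-optional WebSearchConfig pattern;
# only a subset of the real fields is reproduced here.
from typing import Optional

from pydantic import BaseModel


class WebSearchConfig(BaseModel):
    enabled: Optional[bool] = None
    engine: Optional[str] = None
    kagi_search_api_key: Optional[str] = None
    brave_search_api_key: Optional[str] = None


cfg = WebSearchConfig(enabled=True, engine="kagi", kagi_search_api_key="k-...")
print(cfg.brave_search_api_key)  # None: engines you did not touch stay unset
```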
app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation + request.app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language + request.app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url + request.app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation if form_data.web is not None: - app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( + request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( # Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False form_data.web.web_loader_ssl_verification ) - app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled - app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine - app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url - app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key - app.state.config.GOOGLE_PSE_ENGINE_ID = ( + request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled + request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine + request.app.state.config.SEARXNG_QUERY_URL = ( + form_data.web.search.searxng_query_url + ) + request.app.state.config.GOOGLE_PSE_API_KEY = ( + form_data.web.search.google_pse_api_key + ) + request.app.state.config.GOOGLE_PSE_ENGINE_ID = ( form_data.web.search.google_pse_engine_id ) - app.state.config.BRAVE_SEARCH_API_KEY = ( + request.app.state.config.BRAVE_SEARCH_API_KEY = ( form_data.web.search.brave_search_api_key ) - app.state.config.MOJEEK_SEARCH_API_KEY = ( + request.app.state.config.KAGI_SEARCH_API_KEY = ( + form_data.web.search.kagi_search_api_key + ) + request.app.state.config.MOJEEK_SEARCH_API_KEY = ( form_data.web.search.mojeek_search_api_key ) - app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key - app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https - app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key - app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key - app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key - app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key - app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine + request.app.state.config.SERPSTACK_API_KEY = ( + form_data.web.search.serpstack_api_key + ) + request.app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https + request.app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key + request.app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key + request.app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key + request.app.state.config.SEARCHAPI_API_KEY = ( + form_data.web.search.searchapi_api_key + ) + request.app.state.config.SEARCHAPI_ENGINE = ( + form_data.web.search.searchapi_engine + ) - app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key - app.state.config.BING_SEARCH_V7_ENDPOINT = ( + request.app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key + request.app.state.config.BING_SEARCH_V7_ENDPOINT = ( form_data.web.search.bing_search_v7_endpoint ) - app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = ( + request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = ( form_data.web.search.bing_search_v7_subscription_key ) - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count - app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( + 
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = ( + form_data.web.search.result_count + ) + request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( form_data.web.search.concurrent_requests ) return { "status": True, - "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, + "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, "file": { - "max_size": app.state.config.FILE_MAX_SIZE, - "max_count": app.state.config.FILE_MAX_COUNT, + "max_size": request.app.state.config.FILE_MAX_SIZE, + "max_count": request.app.state.config.FILE_MAX_COUNT, }, "content_extraction": { - "engine": app.state.config.CONTENT_EXTRACTION_ENGINE, - "tika_server_url": app.state.config.TIKA_SERVER_URL, + "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, + "tika_server_url": request.app.state.config.TIKA_SERVER_URL, }, "chunk": { - "text_splitter": app.state.config.TEXT_SPLITTER, - "chunk_size": app.state.config.CHUNK_SIZE, - "chunk_overlap": app.state.config.CHUNK_OVERLAP, + "text_splitter": request.app.state.config.TEXT_SPLITTER, + "chunk_size": request.app.state.config.CHUNK_SIZE, + "chunk_overlap": request.app.state.config.CHUNK_OVERLAP, }, "youtube": { - "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, - "proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL, - "translation": app.state.YOUTUBE_LOADER_TRANSLATION, + "language": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, + "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION, }, "web": { - "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "search": { - "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH, - "engine": app.state.config.RAG_WEB_SEARCH_ENGINE, - "searxng_query_url": app.state.config.SEARXNG_QUERY_URL, - "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY, - "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID, - "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY, - "mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY, - "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, - "serpstack_https": app.state.config.SERPSTACK_HTTPS, - "serper_api_key": app.state.config.SERPER_API_KEY, - "serply_api_key": app.state.config.SERPLY_API_KEY, - "serachapi_api_key": app.state.config.SEARCHAPI_API_KEY, - "searchapi_engine": app.state.config.SEARCHAPI_ENGINE, - "tavily_api_key": app.state.config.TAVILY_API_KEY, - "jina_api_key": app.state.config.JINA_API_KEY, - "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT, - "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH, + "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE, + "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL, + "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY, + "google_pse_engine_id": request.app.state.config.GOOGLE_PSE_ENGINE_ID, + "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY, + "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY, + "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY, + "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY, + 
"serpstack_https": request.app.state.config.SERPSTACK_HTTPS, + "serper_api_key": request.app.state.config.SERPER_API_KEY, + "serply_api_key": request.app.state.config.SERPLY_API_KEY, + "serachapi_api_key": request.app.state.config.SEARCHAPI_API_KEY, + "searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE, + "tavily_api_key": request.app.state.config.TAVILY_API_KEY, + "jina_api_key": request.app.state.config.JINA_API_KEY, + "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT, + "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, }, } -@app.get("/template") -async def get_rag_template(user=Depends(get_verified_user)): +@router.get("/template") +async def get_rag_template(request: Request, user=Depends(get_verified_user)): return { "status": True, - "template": app.state.config.RAG_TEMPLATE, + "template": request.app.state.config.RAG_TEMPLATE, } -@app.get("/query/settings") -async def get_query_settings(user=Depends(get_admin_user)): +@router.get("/query/settings") +async def get_query_settings(request: Request, user=Depends(get_admin_user)): return { "status": True, - "template": app.state.config.RAG_TEMPLATE, - "k": app.state.config.TOP_K, - "r": app.state.config.RELEVANCE_THRESHOLD, - "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, + "template": request.app.state.config.RAG_TEMPLATE, + "k": request.app.state.config.TOP_K, + "r": request.app.state.config.RELEVANCE_THRESHOLD, + "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, } @@ -701,24 +615,24 @@ class QuerySettingsForm(BaseModel): hybrid: Optional[bool] = None -@app.post("/query/settings/update") +@router.post("/query/settings/update") async def update_query_settings( - form_data: QuerySettingsForm, user=Depends(get_admin_user) + request: Request, form_data: QuerySettingsForm, user=Depends(get_admin_user) ): - app.state.config.RAG_TEMPLATE = form_data.template - app.state.config.TOP_K = form_data.k if form_data.k else 4 - app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 + request.app.state.config.RAG_TEMPLATE = form_data.template + request.app.state.config.TOP_K = form_data.k if form_data.k else 4 + request.app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 - app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( form_data.hybrid if form_data.hybrid else False ) return { "status": True, - "template": app.state.config.RAG_TEMPLATE, - "k": app.state.config.TOP_K, - "r": app.state.config.RELEVANCE_THRESHOLD, - "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, + "template": request.app.state.config.RAG_TEMPLATE, + "k": request.app.state.config.TOP_K, + "r": request.app.state.config.RELEVANCE_THRESHOLD, + "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, } @@ -729,24 +643,8 @@ async def update_query_settings( #################################### -def _get_docs_info(docs: list[Document]) -> str: - docs_info = set() - - # Trying to select relevant metadata identifying the document. 
- for doc in docs: - metadata = getattr(doc, "metadata", {}) - doc_name = metadata.get("name", "") - if not doc_name: - doc_name = metadata.get("title", "") - if not doc_name: - doc_name = metadata.get("source", "") - if doc_name: - docs_info.add(doc_name) - - return ", ".join(docs_info) - - def save_docs_to_vector_db( + request: Request, docs, collection_name, metadata: Optional[dict] = None, @@ -754,6 +652,22 @@ def save_docs_to_vector_db( split: bool = True, add: bool = False, ) -> bool: + def _get_docs_info(docs: list[Document]) -> str: + docs_info = set() + + # Trying to select relevant metadata identifying the document. + for doc in docs: + metadata = getattr(doc, "metadata", {}) + doc_name = metadata.get("name", "") + if not doc_name: + doc_name = metadata.get("title", "") + if not doc_name: + doc_name = metadata.get("source", "") + if doc_name: + docs_info.add(doc_name) + + return ", ".join(docs_info) + log.info( f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}" ) @@ -772,22 +686,22 @@ def save_docs_to_vector_db( raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) if split: - if app.state.config.TEXT_SPLITTER in ["", "character"]: + if request.app.state.config.TEXT_SPLITTER in ["", "character"]: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, + chunk_size=request.app.state.config.CHUNK_SIZE, + chunk_overlap=request.app.state.config.CHUNK_OVERLAP, add_start_index=True, ) - elif app.state.config.TEXT_SPLITTER == "token": + elif request.app.state.config.TEXT_SPLITTER == "token": log.info( - f"Using token text splitter: {app.state.config.TIKTOKEN_ENCODING_NAME}" + f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}" ) - tiktoken.get_encoding(str(app.state.config.TIKTOKEN_ENCODING_NAME)) + tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME)) text_splitter = TokenTextSplitter( - encoding_name=str(app.state.config.TIKTOKEN_ENCODING_NAME), - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, + encoding_name=str(request.app.state.config.TIKTOKEN_ENCODING_NAME), + chunk_size=request.app.state.config.CHUNK_SIZE, + chunk_overlap=request.app.state.config.CHUNK_OVERLAP, add_start_index=True, ) else: @@ -805,8 +719,8 @@ def save_docs_to_vector_db( **(metadata if metadata else {}), "embedding_config": json.dumps( { - "engine": app.state.config.RAG_EMBEDDING_ENGINE, - "model": app.state.config.RAG_EMBEDDING_MODEL, + "engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "model": request.app.state.config.RAG_EMBEDDING_MODEL, } ), } @@ -835,20 +749,20 @@ def save_docs_to_vector_db( log.info(f"adding to collection {collection_name}") embedding_function = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.ef, ( - app.state.config.OPENAI_API_BASE_URL - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_BASE_URL + request.app.state.config.RAG_OPENAI_API_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.RAG_OLLAMA_BASE_URL ), ( - app.state.config.OPENAI_API_KEY - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_API_KEY + request.app.state.config.RAG_OPENAI_API_KEY + if 
request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.RAG_OLLAMA_API_KEY ), - app.state.config.RAG_EMBEDDING_BATCH_SIZE, + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) embeddings = embedding_function( @@ -882,8 +796,9 @@ class ProcessFileForm(BaseModel): collection_name: Optional[str] = None -@app.post("/process/file") +@router.post("/process/file") def process_file( + request: Request, form_data: ProcessFileForm, user=Depends(get_verified_user), ): @@ -953,9 +868,9 @@ def process_file( if file_path: file_path = Storage.get_file(file_path) loader = Loader( - engine=app.state.config.CONTENT_EXTRACTION_ENGINE, - TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, - PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, + engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL, + PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES, ) docs = loader.load( file.filename, file.meta.get("content_type"), file_path @@ -1000,6 +915,7 @@ def process_file( try: result = save_docs_to_vector_db( + request, docs=docs, collection_name=collection_name, metadata={ @@ -1046,8 +962,9 @@ class ProcessTextForm(BaseModel): collection_name: Optional[str] = None -@app.post("/process/text") +@router.post("/process/text") def process_text( + request: Request, form_data: ProcessTextForm, user=Depends(get_verified_user), ): @@ -1064,8 +981,7 @@ def process_text( text_content = form_data.content log.debug(f"text_content: {text_content}") - result = save_docs_to_vector_db(docs, collection_name) - + result = save_docs_to_vector_db(request, docs, collection_name) if result: return { "status": True, @@ -1079,8 +995,10 @@ def process_text( ) -@app.post("/process/youtube") -def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)): +@router.post("/process/youtube") +def process_youtube_video( + request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user) +): try: collection_name = form_data.collection_name if not collection_name: @@ -1088,14 +1006,15 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u loader = YoutubeLoader( form_data.url, - language=app.state.config.YOUTUBE_LOADER_LANGUAGE, - proxy_url=app.state.config.YOUTUBE_LOADER_PROXY_URL, + language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL, ) docs = loader.load() content = " ".join([doc.page_content for doc in docs]) log.debug(f"text_content: {content}") - save_docs_to_vector_db(docs, collection_name, overwrite=True) + + save_docs_to_vector_db(request, docs, collection_name, overwrite=True) return { "status": True, @@ -1118,8 +1037,10 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u ) -@app.post("/process/web") -def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)): +@router.post("/process/web") +def process_web( + request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user) +): try: collection_name = form_data.collection_name if not collection_name: @@ -1127,13 +1048,14 @@ def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)): loader = get_web_loader( form_data.url, - verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + 
requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) docs = loader.load() content = " ".join([doc.page_content for doc in docs]) + log.debug(f"text_content: {content}") - save_docs_to_vector_db(docs, collection_name, overwrite=True) + save_docs_to_vector_db(request, docs, collection_name, overwrite=True) return { "status": True, @@ -1156,12 +1078,13 @@ def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)): ) -def search_web(engine: str, query: str) -> list[SearchResult]: +def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: """Search the web using a search engine and return the results as a list of SearchResult objects. Will look for a search engine API key in environment variables in the following order: - SEARXNG_QUERY_URL - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID - BRAVE_SEARCH_API_KEY + - KAGI_SEARCH_API_KEY - MOJEEK_SEARCH_API_KEY - SERPSTACK_API_KEY - SERPER_API_KEY @@ -1174,140 +1097,151 @@ def search_web(engine: str, query: str) -> list[SearchResult]: # TODO: add playwright to search the web if engine == "searxng": - if app.state.config.SEARXNG_QUERY_URL: + if request.app.state.config.SEARXNG_QUERY_URL: return search_searxng( - app.state.config.SEARXNG_QUERY_URL, + request.app.state.config.SEARXNG_QUERY_URL, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SEARXNG_QUERY_URL found in environment variables") elif engine == "google_pse": if ( - app.state.config.GOOGLE_PSE_API_KEY - and app.state.config.GOOGLE_PSE_ENGINE_ID + request.app.state.config.GOOGLE_PSE_API_KEY + and request.app.state.config.GOOGLE_PSE_ENGINE_ID ): return search_google_pse( - app.state.config.GOOGLE_PSE_API_KEY, - app.state.config.GOOGLE_PSE_ENGINE_ID, + request.app.state.config.GOOGLE_PSE_API_KEY, + request.app.state.config.GOOGLE_PSE_ENGINE_ID, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception( "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables" ) elif engine == "brave": - if app.state.config.BRAVE_SEARCH_API_KEY: + if request.app.state.config.BRAVE_SEARCH_API_KEY: return search_brave( - app.state.config.BRAVE_SEARCH_API_KEY, + request.app.state.config.BRAVE_SEARCH_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables") - elif engine == "mojeek": - if app.state.config.MOJEEK_SEARCH_API_KEY: - return search_mojeek( - app.state.config.MOJEEK_SEARCH_API_KEY, + elif engine == "kagi": + if request.app.state.config.KAGI_SEARCH_API_KEY: + return search_kagi( + request.app.state.config.KAGI_SEARCH_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No KAGI_SEARCH_API_KEY found in environment variables") + elif engine == "mojeek": + if 
request.app.state.config.MOJEEK_SEARCH_API_KEY: + return search_mojeek( + request.app.state.config.MOJEEK_SEARCH_API_KEY, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables") elif engine == "serpstack": - if app.state.config.SERPSTACK_API_KEY: + if request.app.state.config.SERPSTACK_API_KEY: return search_serpstack( - app.state.config.SERPSTACK_API_KEY, + request.app.state.config.SERPSTACK_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - https_enabled=app.state.config.SERPSTACK_HTTPS, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + https_enabled=request.app.state.config.SERPSTACK_HTTPS, ) else: raise Exception("No SERPSTACK_API_KEY found in environment variables") elif engine == "serper": - if app.state.config.SERPER_API_KEY: + if request.app.state.config.SERPER_API_KEY: return search_serper( - app.state.config.SERPER_API_KEY, + request.app.state.config.SERPER_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SERPER_API_KEY found in environment variables") elif engine == "serply": - if app.state.config.SERPLY_API_KEY: + if request.app.state.config.SERPLY_API_KEY: return search_serply( - app.state.config.SERPLY_API_KEY, + request.app.state.config.SERPLY_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SERPLY_API_KEY found in environment variables") elif engine == "duckduckgo": return search_duckduckgo( query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) elif engine == "tavily": - if app.state.config.TAVILY_API_KEY: + if request.app.state.config.TAVILY_API_KEY: return search_tavily( - app.state.config.TAVILY_API_KEY, + request.app.state.config.TAVILY_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, ) else: raise Exception("No TAVILY_API_KEY found in environment variables") elif engine == "searchapi": - if app.state.config.SEARCHAPI_API_KEY: + if request.app.state.config.SEARCHAPI_API_KEY: return search_searchapi( - app.state.config.SEARCHAPI_API_KEY, - app.state.config.SEARCHAPI_ENGINE, + request.app.state.config.SEARCHAPI_API_KEY, + request.app.state.config.SEARCHAPI_ENGINE, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SEARCHAPI_API_KEY found in environment variables") elif engine == "jina": return search_jina( - app.state.config.JINA_API_KEY, + request.app.state.config.JINA_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, ) elif engine == "bing": return 
search_bing(
- app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
- app.state.config.BING_SEARCH_V7_ENDPOINT,
+ request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
+ request.app.state.config.BING_SEARCH_V7_ENDPOINT,
 str(DEFAULT_LOCALE),
 query,
- app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
- app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+ request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
 )
 else:
 raise Exception("No search engine API key found in environment variables")
-@app.post("/process/web/search")
-def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
+@router.post("/process/web/search")
+def process_web_search(
+ request: Request, form_data: SearchForm, user=Depends(get_verified_user)
+):
 try:
 logging.info(
- f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
+ f"trying to web search with {request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
 )
 web_results = search_web(
- app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
+ request, request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
 )
 except Exception as e:
 log.exception(e)
- print(e)
 raise HTTPException(
 status_code=status.HTTP_400_BAD_REQUEST,
 detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
@@ -1316,18 +1250,19 @@ def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
 try:
 collection_name = form_data.collection_name
 if collection_name == "":
- collection_name = calculate_sha256_string(form_data.query)[:63]
+ collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[
+ :63
+ ]
 urls = [result.link for result in web_results]
- loader = get_web_loader(
- urls,
- verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
- requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+ loader = get_web_loader(
+ urls=urls,
+ verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+ requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
 )
 docs = loader.aload()
- save_docs_to_vector_db(docs, collection_name, overwrite=True)
+ save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
 return {
 "status": True,
@@ -1350,29 +1285,31 @@ class QueryDocForm(BaseModel):
 hybrid: Optional[bool] = None
-@app.post("/query/doc")
+@router.post("/query/doc")
 def query_doc_handler(
+ request: Request,
 form_data: QueryDocForm,
 user=Depends(get_verified_user),
 ):
 try:
- if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
+ if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
 return query_doc_with_hybrid_search(
 collection_name=form_data.collection_name,
 query=form_data.query,
- embedding_function=app.state.EMBEDDING_FUNCTION,
- k=form_data.k if form_data.k else app.state.config.TOP_K,
- reranking_function=app.state.sentence_transformer_rf,
+ embedding_function=request.app.state.EMBEDDING_FUNCTION,
+ k=form_data.k if form_data.k else request.app.state.config.TOP_K,
+ reranking_function=request.app.state.rf,
 r=(
- form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
+ form_data.r
+ if form_data.r
+ else request.app.state.config.RELEVANCE_THRESHOLD
 ),
 )
 else:
 return query_doc(
 collection_name=form_data.collection_name,
- query=form_data.query,
- embedding_function=app.state.EMBEDDING_FUNCTION,
- k=form_data.k if form_data.k else app.state.config.TOP_K,
+ query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query),
+ k=form_data.k if form_data.k else request.app.state.config.TOP_K,
 )
except Exception as e: log.exception(e) @@ -1390,29 +1327,32 @@ class QueryCollectionsForm(BaseModel): hybrid: Optional[bool] = None -@app.post("/query/collection") +@router.post("/query/collection") def query_collection_handler( + request: Request, form_data: QueryCollectionsForm, user=Depends(get_verified_user), ): try: - if app.state.config.ENABLE_RAG_HYBRID_SEARCH: + if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_collection_with_hybrid_search( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, - reranking_function=app.state.sentence_transformer_rf, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else request.app.state.config.TOP_K, + reranking_function=request.app.state.rf, r=( - form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD + form_data.r + if form_data.r + else request.app.state.config.RELEVANCE_THRESHOLD ), ) else: return query_collection( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else request.app.state.config.TOP_K, ) except Exception as e: @@ -1435,7 +1375,7 @@ class DeleteForm(BaseModel): file_id: str -@app.post("/delete") +@router.post("/delete") def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)): try: if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name): @@ -1454,13 +1394,13 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin return {"status": False} -@app.post("/reset/db") +@router.post("/reset/db") def reset_vector_db(user=Depends(get_admin_user)): VECTOR_DB_CLIENT.reset() Knowledges.delete_all_knowledge() -@app.post("/reset/uploads") +@router.post("/reset/uploads") def reset_upload_dir(user=Depends(get_admin_user)) -> bool: folder = f"{UPLOAD_DIR}" try: @@ -1485,10 +1425,6 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool: if ENV == "dev": - @app.get("/ef") - async def get_embeddings(): - return {"result": app.state.EMBEDDING_FUNCTION("hello world")} - - @app.get("/ef/{text}") - async def get_embeddings_text(text: str): - return {"result": app.state.EMBEDDING_FUNCTION(text)} + @router.get("/ef/{text}") + async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"): + return {"result": request.app.state.EMBEDDING_FUNCTION(text)} diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py new file mode 100644 index 000000000..a2a6cdc92 --- /dev/null +++ b/backend/open_webui/routers/tasks.py @@ -0,0 +1,512 @@ +from fastapi import APIRouter, Depends, HTTPException, Response, status, Request +from fastapi.responses import JSONResponse, RedirectResponse + +from pydantic import BaseModel +from typing import Optional +import logging + +from open_webui.utils.chat import generate_chat_completion +from open_webui.utils.task import ( + title_generation_template, + query_generation_template, + autocomplete_generation_template, + tags_generation_template, + emoji_generation_template, + moa_response_generation_template, +) +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.constants import TASKS + +from open_webui.routers.pipelines import 
process_pipeline_inlet_filter +from open_webui.utils.task import get_task_model_id + +from open_webui.config import ( + DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, + DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, + DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, + DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, + DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE, + DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE, +) +from open_webui.env import SRC_LOG_LEVELS + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +router = APIRouter() + + +################################## +# +# Task Endpoints +# +################################## + + +@router.get("/config") +async def get_task_config(request: Request, user=Depends(get_verified_user)): + return { + "TASK_MODEL": request.app.state.config.TASK_MODEL, + "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, + "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, + "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, + "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, + "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + } + + +class TaskConfigForm(BaseModel): + TASK_MODEL: Optional[str] + TASK_MODEL_EXTERNAL: Optional[str] + TITLE_GENERATION_PROMPT_TEMPLATE: str + ENABLE_AUTOCOMPLETE_GENERATION: bool + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int + TAGS_GENERATION_PROMPT_TEMPLATE: str + ENABLE_TAGS_GENERATION: bool + ENABLE_SEARCH_QUERY_GENERATION: bool + ENABLE_RETRIEVAL_QUERY_GENERATION: bool + QUERY_GENERATION_PROMPT_TEMPLATE: str + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str + + +@router.post("/config/update") +async def update_task_config( + request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.TASK_MODEL = form_data.TASK_MODEL + request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL + request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( + form_data.TITLE_GENERATION_PROMPT_TEMPLATE + ) + + request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ( + form_data.ENABLE_AUTOCOMPLETE_GENERATION + ) + request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( + form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH + ) + + request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = ( + form_data.TAGS_GENERATION_PROMPT_TEMPLATE + ) + request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION + request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ( + form_data.ENABLE_SEARCH_QUERY_GENERATION + ) + request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ( + form_data.ENABLE_RETRIEVAL_QUERY_GENERATION + ) + + request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = ( + form_data.QUERY_GENERATION_PROMPT_TEMPLATE + ) + request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + ) + + return { + "TASK_MODEL": 
request.app.state.config.TASK_MODEL, + "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, + "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, + "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, + "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, + "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + } + + +@router.post("/title/completions") +async def generate_title( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating chat title using model {task_model_id} for user {user.email} " + ) + + if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": + template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE + + content = title_generation_template( + template, + form_data["messages"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + **( + {"max_tokens": 50} + if models[task_model_id]["owned_by"] == "ollama" + else { + "max_completion_tokens": 50, + } + ), + "metadata": { + "task": str(TASKS.TITLE_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + +@router.post("/tags/completions") +async def generate_chat_tags( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + if not request.app.state.config.ENABLE_TAGS_GENERATION: + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"detail": "Tags generation is disabled"}, + ) + + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating chat tags using model {task_model_id} for user {user.email} " + ) + + if 
request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": + template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE + + content = tags_generation_template( + template, form_data["messages"], {"name": user.name} + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + "task": str(TASKS.TAGS_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + +@router.post("/queries/completions") +async def generate_queries( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + type = form_data.get("type") + if type == "web_search": + if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Search query generation is disabled", + ) + elif type == "retrieval": + if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Query generation is disabled", + ) + + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating {type} queries using model {task_model_id} for user {user.email}" + ) + + if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "": + template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE + + content = query_generation_template( + template, form_data["messages"], {"name": user.name} + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + "task": str(TASKS.QUERY_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + +@router.post("/auto/completions") +async def generate_autocompletion( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Autocompletion generation is disabled", + ) + + type = form_data.get("type") + prompt = form_data.get("prompt") + messages = form_data.get("messages") + + if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0: + if ( + len(prompt) + > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", + ) + + models = request.app.state.MODELS + + model_id = 
form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating autocompletion using model {task_model_id} for user {user.email}" + ) + + if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "": + template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE + + content = autocomplete_generation_template( + template, prompt, messages, type, {"name": user.name} + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + "task": str(TASKS.AUTOCOMPLETE_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + +@router.post("/emoji/completions") +async def generate_emoji( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug(f"generating emoji using model {task_model_id} for user {user.email} ") + + template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE + + content = emoji_generation_template( + template, + form_data["prompt"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + **( + {"max_tokens": 4} + if models[task_model_id]["owned_by"] == "ollama" + else { + "max_completion_tokens": 4, + } + ), + "chat_id": form_data.get("chat_id", None), + "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data}, + } + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + +@router.post("/moa/completions") +async def generate_moa_response( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + models = request.app.state.MODELS + model_id = form_data["model"] + + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug(f"generating MOA model {task_model_id} for user {user.email} ") + + template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE + + content = 
moa_response_generation_template( + template, + form_data["prompt"], + form_data["responses"], + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": form_data.get("stream", False), + "chat_id": form_data.get("chat_id", None), + "metadata": { + "task": str(TASKS.MOA_RESPONSE_GENERATION), + "task_body": form_data, + }, + } + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) diff --git a/backend/open_webui/apps/webui/routers/tools.py b/backend/open_webui/routers/tools.py similarity index 98% rename from backend/open_webui/apps/webui/routers/tools.py rename to backend/open_webui/routers/tools.py index d0523ddac..9e95ebe5a 100644 --- a/backend/open_webui/apps/webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -1,19 +1,19 @@ from pathlib import Path from typing import Optional -from open_webui.apps.webui.models.tools import ( +from open_webui.models.tools import ( ToolForm, ToolModel, ToolResponse, ToolUserResponse, Tools, ) -from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports +from open_webui.utils.plugin import load_tools_module_by_id, replace_imports from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.tools import get_tools_specs -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission diff --git a/backend/open_webui/apps/webui/routers/users.py b/backend/open_webui/routers/users.py similarity index 97% rename from backend/open_webui/apps/webui/routers/users.py rename to backend/open_webui/routers/users.py index b6b91a5c3..1206d56f2 100644 --- a/backend/open_webui/apps/webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -1,9 +1,9 @@ import logging from typing import Optional -from open_webui.apps.webui.models.auths import Auths -from open_webui.apps.webui.models.chats import Chats -from open_webui.apps.webui.models.users import ( +from open_webui.models.auths import Auths +from open_webui.models.chats import Chats +from open_webui.models.users import ( UserModel, UserRoleUpdateForm, Users, @@ -14,7 +14,7 @@ from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_password_hash, get_verified_user +from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/open_webui/apps/webui/routers/utils.py b/backend/open_webui/routers/utils.py similarity index 93% rename from backend/open_webui/apps/webui/routers/utils.py rename to backend/open_webui/routers/utils.py index 0ab0f6b15..ea73e9759 100644 --- a/backend/open_webui/apps/webui/routers/utils.py +++ b/backend/open_webui/routers/utils.py @@ -1,7 +1,7 @@ import black import markdown -from open_webui.apps.webui.models.chats import ChatTitleMessagesForm +from open_webui.models.chats import ChatTitleMessagesForm from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT from open_webui.constants 
import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Response, status @@ -9,7 +9,7 @@ from pydantic import BaseModel from starlette.responses import FileResponse from open_webui.utils.misc import get_gravatar_url from open_webui.utils.pdf_generator import PDFGenerator -from open_webui.utils.utils import get_admin_user +from open_webui.utils.auth import get_admin_user router = APIRouter() @@ -76,7 +76,7 @@ async def download_db(user=Depends(get_admin_user)): status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - from open_webui.apps.webui.internal.db import engine + from open_webui.internal.db import engine if engine.name != "sqlite": raise HTTPException( diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/socket/main.py similarity index 96% rename from backend/open_webui/apps/socket/main.py rename to backend/open_webui/socket/main.py index 5c284f18d..8343be666 100644 --- a/backend/open_webui/apps/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -1,19 +1,17 @@ -# TODO: move socket to webui app - import asyncio import socketio import logging import sys import time -from open_webui.apps.webui.models.users import Users +from open_webui.models.users import Users from open_webui.env import ( ENABLE_WEBSOCKET_SUPPORT, WEBSOCKET_MANAGER, WEBSOCKET_REDIS_URL, ) -from open_webui.utils.utils import decode_token -from open_webui.apps.socket.utils import RedisDict +from open_webui.utils.auth import decode_token +from open_webui.socket.utils import RedisDict from open_webui.env import ( GLOBAL_LOG_LEVEL, @@ -173,6 +171,11 @@ async def user_count(sid): await sio.emit("user-count", {"count": len(USER_POOL.items())}) +@sio.on("chat") +async def chat(sid, data): + print("chat", sid, SESSION_POOL[sid], data) + + @sio.event async def disconnect(sid): if sid in SESSION_POOL: diff --git a/backend/open_webui/apps/socket/utils.py b/backend/open_webui/socket/utils.py similarity index 100% rename from backend/open_webui/apps/socket/utils.py rename to backend/open_webui/socket/utils.py diff --git a/backend/open_webui/static/assets/pdf-style.css b/backend/open_webui/static/assets/pdf-style.css index db9ac83dd..85c36271c 100644 --- a/backend/open_webui/static/assets/pdf-style.css +++ b/backend/open_webui/static/assets/pdf-style.css @@ -26,7 +26,7 @@ html { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'NotoSans', 'NotoSansJP', 'NotoSansKR', - 'NotoSansSC', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium', Roboto, + 'NotoSansSC', 'Twemoji', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium', Roboto, 'Helvetica Neue', Arial, sans-serif; font-size: 14px; /* Default font size */ line-height: 1.5; @@ -40,7 +40,7 @@ html { body { margin: 0; - color: #212529; + padding: 0; background-color: #fff; width: auto; } diff --git a/backend/open_webui/static/fonts/Twemoji.ttf b/backend/open_webui/static/fonts/Twemoji.ttf new file mode 100644 index 000000000..281d356d9 Binary files /dev/null and b/backend/open_webui/static/fonts/Twemoji.ttf differ diff --git a/backend/open_webui/test/apps/webui/routers/test_auths.py b/backend/open_webui/test/apps/webui/routers/test_auths.py index bc14fb8dd..f0f69e26d 100644 --- a/backend/open_webui/test/apps/webui/routers/test_auths.py +++ b/backend/open_webui/test/apps/webui/routers/test_auths.py @@ -7,8 +7,8 @@ class TestAuths(AbstractPostgresTest): def setup_class(cls): super().setup_class() - from open_webui.apps.webui.models.auths import Auths - from 
open_webui.apps.webui.models.users import Users + from open_webui.models.auths import Auths + from open_webui.models.users import Users cls.users = Users cls.auths = Auths @@ -26,7 +26,7 @@ class TestAuths(AbstractPostgresTest): } def test_update_profile(self): - from open_webui.utils.utils import get_password_hash + from open_webui.utils.auth import get_password_hash user = self.auths.insert_new_auth( email="john.doe@openwebui.com", @@ -47,7 +47,7 @@ class TestAuths(AbstractPostgresTest): assert db_user.profile_image_url == "/user2.png" def test_update_password(self): - from open_webui.utils.utils import get_password_hash + from open_webui.utils.auth import get_password_hash user = self.auths.insert_new_auth( email="john.doe@openwebui.com", @@ -74,7 +74,7 @@ class TestAuths(AbstractPostgresTest): assert new_auth is not None def test_signin(self): - from open_webui.utils.utils import get_password_hash + from open_webui.utils.auth import get_password_hash user = self.auths.insert_new_auth( email="john.doe@openwebui.com", diff --git a/backend/open_webui/test/apps/webui/routers/test_chats.py b/backend/open_webui/test/apps/webui/routers/test_chats.py index 935316fd8..a36a01fb1 100644 --- a/backend/open_webui/test/apps/webui/routers/test_chats.py +++ b/backend/open_webui/test/apps/webui/routers/test_chats.py @@ -12,7 +12,7 @@ class TestChats(AbstractPostgresTest): def setup_method(self): super().setup_method() - from open_webui.apps.webui.models.chats import ChatForm, Chats + from open_webui.models.chats import ChatForm, Chats self.chats = Chats self.chats.insert_new_chat( @@ -88,7 +88,7 @@ class TestChats(AbstractPostgresTest): def test_get_user_archived_chats(self): self.chats.archive_all_chats_by_user_id("2") - from open_webui.apps.webui.internal.db import Session + from open_webui.internal.db import Session Session.commit() with mock_webui_user(id="2"): diff --git a/backend/open_webui/test/apps/webui/routers/test_models.py b/backend/open_webui/test/apps/webui/routers/test_models.py index 1d52658b8..c16ca9d07 100644 --- a/backend/open_webui/test/apps/webui/routers/test_models.py +++ b/backend/open_webui/test/apps/webui/routers/test_models.py @@ -7,7 +7,7 @@ class TestModels(AbstractPostgresTest): def setup_class(cls): super().setup_class() - from open_webui.apps.webui.models.models import Model + from open_webui.models.models import Model cls.models = Model diff --git a/backend/open_webui/test/apps/webui/routers/test_users.py b/backend/open_webui/test/apps/webui/routers/test_users.py index 6facf7055..1a58ab147 100644 --- a/backend/open_webui/test/apps/webui/routers/test_users.py +++ b/backend/open_webui/test/apps/webui/routers/test_users.py @@ -25,7 +25,7 @@ class TestUsers(AbstractPostgresTest): def setup_class(cls): super().setup_class() - from open_webui.apps.webui.models.users import Users + from open_webui.models.users import Users cls.users = Users diff --git a/backend/open_webui/test/util/abstract_integration_test.py b/backend/open_webui/test/util/abstract_integration_test.py index 2814731e0..e8492befb 100644 --- a/backend/open_webui/test/util/abstract_integration_test.py +++ b/backend/open_webui/test/util/abstract_integration_test.py @@ -115,7 +115,7 @@ class AbstractPostgresTest(AbstractIntegrationTest): pytest.fail(f"Could not setup test environment: {ex}") def _check_db_connection(self): - from open_webui.apps.webui.internal.db import Session + from open_webui.internal.db import Session retries = 10 while retries > 0: @@ -139,7 +139,7 @@ class 
AbstractPostgresTest(AbstractIntegrationTest): cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True) def teardown_method(self): - from open_webui.apps.webui.internal.db import Session + from open_webui.internal.db import Session # rollback everything not yet committed Session.commit() diff --git a/backend/open_webui/test/util/mock_user.py b/backend/open_webui/test/util/mock_user.py index 96456a2c8..7ce64dffa 100644 --- a/backend/open_webui/test/util/mock_user.py +++ b/backend/open_webui/test/util/mock_user.py @@ -5,7 +5,7 @@ from fastapi import FastAPI @contextmanager def mock_webui_user(**kwargs): - from open_webui.apps.webui.main import app + from open_webui.routers.webui import app with mock_user(app, **kwargs): yield @@ -13,13 +13,13 @@ def mock_webui_user(**kwargs): @contextmanager def mock_user(app: FastAPI, **kwargs): - from open_webui.utils.utils import ( + from open_webui.utils.auth import ( get_current_user, get_verified_user, get_admin_user, get_current_user_by_api_key, ) - from open_webui.apps.webui.models.users import User + from open_webui.models.users import User def create_user(): user_parameters = { diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py index 270b28bcc..3b3e75a8b 100644 --- a/backend/open_webui/utils/access_control.py +++ b/backend/open_webui/utils/access_control.py @@ -1,5 +1,5 @@ from typing import Optional, Union, List, Dict, Any -from open_webui.apps.webui.models.groups import Groups +from open_webui.models.groups import Groups import json diff --git a/backend/open_webui/utils/utils.py b/backend/open_webui/utils/auth.py similarity index 98% rename from backend/open_webui/utils/utils.py rename to backend/open_webui/utils/auth.py index cde953102..e1a0ca671 100644 --- a/backend/open_webui/utils/utils.py +++ b/backend/open_webui/utils/auth.py @@ -5,7 +5,7 @@ import jwt from datetime import UTC, datetime, timedelta from typing import Optional, Union, List, Dict -from open_webui.apps.webui.models.users import Users +from open_webui.models.users import Users from open_webui.constants import ERROR_MESSAGES from open_webui.env import WEBUI_SECRET_KEY diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py new file mode 100644 index 000000000..56904d1d8 --- /dev/null +++ b/backend/open_webui/utils/chat.py @@ -0,0 +1,374 @@ +import time +import logging +import sys + +from aiocache import cached +from typing import Any, Optional +import random +import json +import inspect + +from fastapi import Request +from starlette.responses import Response, StreamingResponse + + +from open_webui.models.users import UserModel + +from open_webui.socket.main import ( + get_event_call, + get_event_emitter, +) +from open_webui.functions import generate_function_chat_completion + +from open_webui.routers.openai import ( + generate_chat_completion as generate_openai_chat_completion, +) + +from open_webui.routers.ollama import ( + generate_chat_completion as generate_ollama_chat_completion, +) + +from open_webui.routers.pipelines import ( + process_pipeline_inlet_filter, + process_pipeline_outlet_filter, +) + +from open_webui.models.functions import Functions +from open_webui.models.models import Models + + +from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.models import get_all_models, check_model_access +from open_webui.utils.payload import convert_payload_openai_to_ollama +from open_webui.utils.response import ( + convert_response_ollama_to_openai, + 
    convert_streaming_response_ollama_to_openai,
+)
+
+
+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
+
+
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+
+async def generate_chat_completion(
+    request: Request,
+    form_data: dict,
+    user: Any,
+    bypass_filter: bool = False,
+):
+    if BYPASS_MODEL_ACCESS_CONTROL:
+        bypass_filter = True
+
+    models = request.app.state.MODELS
+
+    model_id = form_data["model"]
+    if model_id not in models:
+        raise Exception("Model not found")
+
+    # Process the form_data through the pipeline
+    try:
+        form_data = process_pipeline_inlet_filter(request, form_data, user, models)
+    except Exception as e:
+        raise e
+
+    model = models[model_id]
+
+    # Check if user has access to the model
+    if not bypass_filter and user.role == "user":
+        try:
+            check_model_access(user, model)
+        except Exception as e:
+            raise e
+
+    if model["owned_by"] == "arena":
+        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
+        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
+        if model_ids and filter_mode == "exclude":
+            model_ids = [
+                model["id"]
+                for model in await get_all_models(request)
+                if model.get("owned_by") != "arena" and model["id"] not in model_ids
+            ]
+
+        selected_model_id = None
+        if isinstance(model_ids, list) and model_ids:
+            selected_model_id = random.choice(model_ids)
+        else:
+            model_ids = [
+                model["id"]
+                for model in await get_all_models(request)
+                if model.get("owned_by") != "arena"
+            ]
+            selected_model_id = random.choice(model_ids)
+
+        form_data["model"] = selected_model_id
+
+        if form_data.get("stream") == True:
+
+            async def stream_wrapper(stream):
+                yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
+                async for chunk in stream:
+                    yield chunk
+
+            response = await generate_chat_completion(
+                request, form_data, user, bypass_filter=True
+            )
+            return StreamingResponse(
+                stream_wrapper(response.body_iterator), media_type="text/event-stream"
+            )
+        else:
+            return {
+                **(await generate_chat_completion(request, form_data, user, bypass_filter=True)),
+                "selected_model_id": selected_model_id,
+            }
+
+    if model.get("pipe"):
+        # Below does not require bypass_filter because this is the only route that uses this function and it is already bypassing the filter
+        return await generate_function_chat_completion(
+            form_data, user=user, models=models
+        )
+    if model["owned_by"] == "ollama":
+        # Using /ollama/api/chat endpoint
+        form_data = convert_payload_openai_to_ollama(form_data)
+        response = await generate_ollama_chat_completion(
+            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
+        )
+        if form_data.get("stream"):
+            response.headers["content-type"] = "text/event-stream"
+            return StreamingResponse(
+                convert_streaming_response_ollama_to_openai(response),
+                headers=dict(response.headers),
+            )
+        else:
+            return convert_response_ollama_to_openai(response)
+    else:
+        return await generate_openai_chat_completion(
+            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
+        )
+
+
+async def chat_completed(request: Request, form_data: dict, user: Any):
+    await get_all_models(request)
+    models = request.app.state.MODELS
+
+    data = form_data
+    model_id = data["model"]
+    if model_id not in models:
+        raise Exception("Model not found")
+
+    model = models[model_id]
+
+    try:
+        data = process_pipeline_outlet_filter(request, data, user, models)
+    except Exception as e:
+        return Exception(f"Error: {e}")
+
+    __event_emitter__ = get_event_emitter(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
+
+    __event_call__ = get_event_call(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
+
+    def get_priority(function_id):
+        function = Functions.get_function_by_id(function_id)
+        if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel to include valves
+            return (function.valves if function.valves else {}).get("priority", 0)
+        return 0
+
+    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
+    if "info" in model and "meta" in model["info"]:
+        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+    filter_ids = list(set(filter_ids))
+
+    enabled_filter_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("filter", active_only=True)
+    ]
+    filter_ids = [
+        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+    ]
+
+    # Sort filter_ids by priority, using the get_priority function
+    filter_ids.sort(key=get_priority)
+
+    for filter_id in filter_ids:
+        filter = Functions.get_function_by_id(filter_id)
+        if not filter:
+            continue
+
+        if filter_id in request.app.state.FUNCTIONS:
+            function_module = request.app.state.FUNCTIONS[filter_id]
+        else:
+            function_module, _, _ = load_function_module_by_id(filter_id)
+            request.app.state.FUNCTIONS[filter_id] = function_module
+
+        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+            valves = Functions.get_function_valves_by_id(filter_id)
+            function_module.valves = function_module.Valves(
+                **(valves if valves else {})
+            )
+
+        if not hasattr(function_module, "outlet"):
+            continue
+        try:
+            outlet = function_module.outlet
+
+            # Get the signature of the function
+            sig = inspect.signature(outlet)
+            params = {"body": data}
+
+            # Extra parameters to be passed to the function
+            extra_params = {
+                "__model__": model,
+                "__id__": filter_id,
+                "__event_emitter__": __event_emitter__,
+                "__event_call__": __event_call__,
+                "__request__": request,
+            }
+
+            # Add extra params contained in the function signature
+            for key, value in extra_params.items():
+                if key in sig.parameters:
+                    params[key] = value
+
+            if "__user__" in sig.parameters:
+                __user__ = {
+                    "id": user.id,
+                    "email": user.email,
+                    "name": user.name,
+                    "role": user.role,
+                }
+
+                try:
+                    if hasattr(function_module, "UserValves"):
+                        __user__["valves"] = function_module.UserValves(
+                            **Functions.get_user_valves_by_id_and_user_id(
+                                filter_id, user.id
+                            )
+                        )
+                except Exception as e:
+                    print(e)
+
+                params = {**params, "__user__": __user__}
+
+            if inspect.iscoroutinefunction(outlet):
+                data = await outlet(**params)
+            else:
+                data = outlet(**params)
+
+        except Exception as e:
+            return Exception(f"Error: {e}")
+
+    return data
+
+
+async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
+    if "." in action_id:
+        action_id, sub_action_id = action_id.split(".")
+    else:
+        sub_action_id = None
+
+    action = Functions.get_function_by_id(action_id)
+    if not action:
+        raise Exception(f"Action not found: {action_id}")
+
+    await get_all_models(request)
+    models = request.app.state.MODELS
+
+    data = form_data
+    model_id = data["model"]
+
+    if model_id not in models:
+        raise Exception("Model not found")
+    model = models[model_id]
+
+    __event_emitter__ = get_event_emitter(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
+    __event_call__ = get_event_call(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
+
+    if action_id in request.app.state.FUNCTIONS:
+        function_module = request.app.state.FUNCTIONS[action_id]
+    else:
+        function_module, _, _ = load_function_module_by_id(action_id)
+        request.app.state.FUNCTIONS[action_id] = function_module
+
+    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+        valves = Functions.get_function_valves_by_id(action_id)
+        function_module.valves = function_module.Valves(**(valves if valves else {}))
+
+    if hasattr(function_module, "action"):
+        try:
+            action = function_module.action
+
+            # Get the signature of the function
+            sig = inspect.signature(action)
+            params = {"body": data}
+
+            # Extra parameters to be passed to the function
+            extra_params = {
+                "__model__": model,
+                "__id__": sub_action_id if sub_action_id is not None else action_id,
+                "__event_emitter__": __event_emitter__,
+                "__event_call__": __event_call__,
+                "__request__": request,
+            }
+
+            # Add extra params contained in the function signature
+            for key, value in extra_params.items():
+                if key in sig.parameters:
+                    params[key] = value
+
+            if "__user__" in sig.parameters:
+                __user__ = {
+                    "id": user.id,
+                    "email": user.email,
+                    "name": user.name,
+                    "role": user.role,
+                }
+
+                try:
+                    if hasattr(function_module, "UserValves"):
+                        __user__["valves"] = function_module.UserValves(
+                            **Functions.get_user_valves_by_id_and_user_id(
+                                action_id, user.id
+                            )
+                        )
+                except Exception as e:
+                    print(e)
+
+                params = {**params, "__user__": __user__}
+
+            if inspect.iscoroutinefunction(action):
+                data = await action(**params)
+            else:
+                data = action(**params)
+
+        except Exception as e:
+            return Exception(f"Error: {e}")
+
+    return data
diff --git a/backend/open_webui/apps/images/utils/comfyui.py b/backend/open_webui/utils/images/comfyui.py
similarity index 100%
rename from backend/open_webui/apps/images/utils/comfyui.py
rename to backend/open_webui/utils/images/comfyui.py
diff --git a/backend/open_webui/utils/logo.png b/backend/open_webui/utils/logo.png
deleted file mode 100644
index 519af1db6..000000000
Binary files a/backend/open_webui/utils/logo.png and /dev/null differ
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
new file mode 100644
index 000000000..1d2bc2b99
--- /dev/null
+++ b/backend/open_webui/utils/middleware.py
@@ -0,0 +1,508 @@
+import time
+import logging
+import sys
+
+from aiocache import cached
+from typing import Any, Optional
+import random
+import json
+import inspect
+
+from fastapi import Request
+from starlette.responses import Response, StreamingResponse
+
+
+from open_webui.socket.main import (
+    get_event_call,
+    get_event_emitter,
+)
+from open_webui.routers.tasks import generate_queries
+
+
+from open_webui.models.users import UserModel
+from open_webui.models.functions import Functions
+from open_webui.models.models
import Models + +from open_webui.retrieval.utils import get_sources_from_files + + +from open_webui.utils.chat import generate_chat_completion +from open_webui.utils.task import ( + get_task_model_id, + rag_template, + tools_function_calling_generation_template, +) +from open_webui.utils.misc import ( + add_or_update_system_message, + get_last_user_message, + prepend_to_first_user_message_content, +) +from open_webui.utils.tools import get_tools +from open_webui.utils.plugin import load_function_module_by_id + + +from open_webui.config import DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL +from open_webui.constants import TASKS + + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +async def chat_completion_filter_functions_handler(request, body, model, extra_params): + skip_files = None + + def get_filter_function_ids(model): + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [ + function.id for function in Functions.get_global_filter_functions() + ] + if "info" in model and "meta" in model["info"]: + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) + + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + + filter_ids = [ + filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids + ] + + filter_ids.sort(key=get_priority) + return filter_ids + + filter_ids = get_filter_function_ids(model) + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if not filter: + continue + + if filter_id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + request.app.state.FUNCTIONS[filter_id] = function_module + + # Check if the function has a file_handler variable + if hasattr(function_module, "file_handler"): + skip_files = function_module.file_handler + + # Apply valves to the function + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) + + if hasattr(function_module, "inlet"): + try: + inlet = function_module.inlet + + # Create a dictionary of parameters to be passed to the function + params = {"body": body} | { + k: v + for k, v in { + **extra_params, + "__model__": model, + "__id__": filter_id, + }.items() + if k in inspect.signature(inlet).parameters + } + + if "__user__" in params and hasattr(function_module, "UserValves"): + try: + params["__user__"]["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, params["__user__"]["id"] + ) + ) + except Exception as e: + print(e) + + if inspect.iscoroutinefunction(inlet): + body = await inlet(**params) + else: + body = inlet(**params) + + except Exception as e: + print(f"Error: {e}") + raise e + + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] + + return body, {} + + +async def chat_completion_tools_handler( + request: Request, body: 
dict, user: UserModel, models, extra_params: dict +) -> tuple[dict, dict]: + async def get_content_from_response(response) -> Optional[str]: + content = None + if hasattr(response, "body_iterator"): + async for chunk in response.body_iterator: + data = json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() + else: + content = response["choices"][0]["message"]["content"] + return content + + def get_tools_function_calling_payload(messages, task_model_id, content): + user_message = get_last_user_message(messages) + history = "\n".join( + f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" + for message in messages[::-1][:4] + ) + + prompt = f"History:\n{history}\nQuery: {user_message}" + + return { + "model": task_model_id, + "messages": [ + {"role": "system", "content": content}, + {"role": "user", "content": f"Query: {prompt}"}, + ], + "stream": False, + "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, + } + + # If tool_ids field is present, call the functions + metadata = body.get("metadata", {}) + + tool_ids = metadata.get("tool_ids", None) + log.debug(f"{tool_ids=}") + if not tool_ids: + return body, {} + + skip_files = False + sources = [] + + task_model_id = get_task_model_id( + body["model"], + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + tools = get_tools( + request, + tool_ids, + user, + { + **extra_params, + "__model__": models[task_model_id], + "__messages__": body["messages"], + "__files__": metadata.get("files", []), + }, + ) + log.info(f"{tools=}") + + specs = [tool["spec"] for tool in tools.values()] + tools_specs = json.dumps(specs) + + if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "": + template = request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + else: + template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + + tools_function_calling_prompt = tools_function_calling_generation_template( + template, tools_specs + ) + log.info(f"{tools_function_calling_prompt=}") + payload = get_tools_function_calling_payload( + body["messages"], task_model_id, tools_function_calling_prompt + ) + + try: + response = await generate_chat_completion(request, form_data=payload, user=user) + log.debug(f"{response=}") + content = await get_content_from_response(response) + log.debug(f"{content=}") + + if not content: + return body, {} + + try: + content = content[content.find("{") : content.rfind("}") + 1] + if not content: + raise Exception("No JSON object found in the response") + + result = json.loads(content) + + tool_function_name = result.get("name", None) + if tool_function_name not in tools: + return body, {} + + tool_function_params = result.get("parameters", {}) + + try: + required_params = ( + tools[tool_function_name] + .get("spec", {}) + .get("parameters", {}) + .get("required", []) + ) + tool_function = tools[tool_function_name]["callable"] + tool_function_params = { + k: v + for k, v in tool_function_params.items() + if k in required_params + } + tool_output = await tool_function(**tool_function_params) + + except Exception as e: + tool_output = str(e) + + if isinstance(tool_output, str): + if tools[tool_function_name]["citation"]: + sources.append( + { + "source": { + "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + }, + "document": [tool_output], + "metadata": [ + { + "source": 
f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + } + ], + } + ) + else: + sources.append( + { + "source": {}, + "document": [tool_output], + "metadata": [ + { + "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + } + ], + } + ) + + if tools[tool_function_name]["file_handler"]: + skip_files = True + + except Exception as e: + log.exception(f"Error: {e}") + content = None + except Exception as e: + log.exception(f"Error: {e}") + content = None + + log.debug(f"tool_contexts: {sources}") + + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] + + return body, {"sources": sources} + + +async def chat_completion_files_handler( + request: Request, body: dict, user: UserModel +) -> tuple[dict, dict[str, list]]: + sources = [] + + if files := body.get("metadata", {}).get("files", None): + try: + queries_response = await generate_queries( + { + "model": body["model"], + "messages": body["messages"], + "type": "retrieval", + }, + user, + ) + queries_response = queries_response["choices"][0]["message"]["content"] + + try: + bracket_start = queries_response.find("{") + bracket_end = queries_response.rfind("}") + 1 + + if bracket_start == -1 or bracket_end == -1: + raise Exception("No JSON object found in the response") + + queries_response = queries_response[bracket_start:bracket_end] + queries_response = json.loads(queries_response) + except Exception as e: + queries_response = {"queries": [queries_response]} + + queries = queries_response.get("queries", []) + except Exception as e: + queries = [] + + if len(queries) == 0: + queries = [get_last_user_message(body["messages"])] + + sources = get_sources_from_files( + files=files, + queries=queries, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=request.app.state.config.TOP_K, + reranking_function=request.app.state.rf, + r=request.app.state.config.RELEVANCE_THRESHOLD, + hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, + ) + + log.debug(f"rag_contexts:sources: {sources}") + return body, {"sources": sources} + + +async def process_chat_payload(request, form_data, user, model): + metadata = { + "chat_id": form_data.pop("chat_id", None), + "message_id": form_data.pop("id", None), + "session_id": form_data.pop("session_id", None), + "tool_ids": form_data.get("tool_ids", None), + "files": form_data.get("files", None), + } + form_data["metadata"] = metadata + + extra_params = { + "__event_emitter__": get_event_emitter(metadata), + "__event_call__": get_event_call(metadata), + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + "__metadata__": metadata, + "__request__": request, + } + + # Initialize events to store additional event to be sent to the client + # Initialize contexts and citation + models = request.app.state.MODELS + events = [] + sources = [] + + try: + form_data, flags = await chat_completion_filter_functions_handler( + request, form_data, model, extra_params + ) + except Exception as e: + return Exception(f"Error: {e}") + + tool_ids = form_data.pop("tool_ids", None) + files = form_data.pop("files", None) + + metadata = { + **metadata, + "tool_ids": tool_ids, + "files": files, + } + form_data["metadata"] = metadata + + try: + form_data, flags = await chat_completion_tools_handler( + request, form_data, user, models, extra_params + ) + sources.extend(flags.get("sources", [])) + except Exception as e: + log.exception(e) + + try: + form_data, flags = await 
chat_completion_files_handler(request, form_data, user) + sources.extend(flags.get("sources", [])) + except Exception as e: + log.exception(e) + + # If context is not empty, insert it into the messages + if len(sources) > 0: + context_string = "" + for source_idx, source in enumerate(sources): + source_id = source.get("source", {}).get("name", "") + + if "document" in source: + for doc_idx, doc_context in enumerate(source["document"]): + metadata = source.get("metadata") + doc_source_id = None + + if metadata: + doc_source_id = metadata[doc_idx].get("source", source_id) + + if source_id: + context_string += f"{doc_source_id if doc_source_id is not None else source_id}{doc_context}\n" + else: + # If there is no source_id, then do not include the source_id tag + context_string += f"{doc_context}\n" + + context_string = context_string.strip() + prompt = get_last_user_message(form_data["messages"]) + + if prompt is None: + raise Exception("No user message found") + if ( + request.app.state.config.RELEVANCE_THRESHOLD == 0 + and context_string.strip() == "" + ): + log.debug( + f"With a 0 relevancy threshold for RAG, the context cannot be empty" + ) + + # Workaround for Ollama 2.0+ system prompt issue + # TODO: replace with add_or_update_system_message + if model["owned_by"] == "ollama": + form_data["messages"] = prepend_to_first_user_message_content( + rag_template( + request.app.state.config.RAG_TEMPLATE, context_string, prompt + ), + form_data["messages"], + ) + else: + form_data["messages"] = add_or_update_system_message( + rag_template( + request.app.state.config.RAG_TEMPLATE, context_string, prompt + ), + form_data["messages"], + ) + + # If there are citations, add them to the data_items + sources = [source for source in sources if source.get("source", {}).get("name", "")] + + if len(sources) > 0: + events.append({"sources": sources}) + + return form_data, events + + +async def process_chat_response(response, events): + if not isinstance(response, StreamingResponse): + return response + + content_type = response.headers["Content-Type"] + is_openai = "text/event-stream" in content_type + is_ollama = "application/x-ndjson" in content_type + + if not is_openai and not is_ollama: + return response + + async def stream_wrapper(original_generator, events): + def wrap_item(item): + return f"data: {item}\n\n" if is_openai else f"{item}\n" + + for event in events: + yield wrap_item(json.dumps(event)) + + async for data in original_generator: + yield data + + return StreamingResponse( + stream_wrapper(response.body_iterator, events), + headers=dict(response.headers), + ) diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index a5af492ba..aba696f60 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -106,7 +106,7 @@ def openai_chat_message_template(model: str): def openai_chat_chunk_message_template( - model: str, message: Optional[str] = None + model: str, message: Optional[str] = None, usage: Optional[dict] = None ) -> dict: template = openai_chat_message_template(model) template["object"] = "chat.completion.chunk" @@ -114,17 +114,23 @@ def openai_chat_chunk_message_template( template["choices"][0]["delta"] = {"content": message} else: template["choices"][0]["finish_reason"] = "stop" + + if usage: + template["usage"] = usage return template def openai_chat_completion_message_template( - model: str, message: Optional[str] = None + model: str, message: Optional[str] = None, usage: Optional[dict] = None ) -> dict: template = 
openai_chat_message_template(model) template["object"] = "chat.completion" if message is not None: template["choices"][0]["message"] = {"content": message, "role": "assistant"} template["choices"][0]["finish_reason"] = "stop" + + if usage: + template["usage"] = usage return template diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py new file mode 100644 index 000000000..b9a4f07a3 --- /dev/null +++ b/backend/open_webui/utils/models.py @@ -0,0 +1,246 @@ +import time +import logging +import sys + +from aiocache import cached +from fastapi import Request + +from open_webui.routers import openai, ollama +from open_webui.functions import get_function_models + + +from open_webui.models.functions import Functions +from open_webui.models.models import Models + + +from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.access_control import has_access + + +from open_webui.config import ( + DEFAULT_ARENA_MODEL, +) + +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL + + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +async def get_all_base_models(request: Request): + function_models = [] + openai_models = [] + ollama_models = [] + + if request.app.state.config.ENABLE_OPENAI_API: + openai_models = await openai.get_all_models(request) + openai_models = openai_models["data"] + + if request.app.state.config.ENABLE_OLLAMA_API: + ollama_models = await ollama.get_all_models(request) + ollama_models = [ + { + "id": model["model"], + "name": model["name"], + "object": "model", + "created": int(time.time()), + "owned_by": "ollama", + "ollama": model, + } + for model in ollama_models["models"] + ] + + function_models = await get_function_models(request) + models = function_models + openai_models + ollama_models + + return models + + +@cached(ttl=3) +async def get_all_models(request): + models = await get_all_base_models(request) + + # If there are no models, return an empty list + if len(models) == 0: + return [] + + # Add arena models + if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS: + arena_models = [] + if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0: + arena_models = [ + { + "id": model["id"], + "name": model["name"], + "info": { + "meta": model["meta"], + }, + "object": "model", + "created": int(time.time()), + "owned_by": "arena", + "arena": True, + } + for model in request.app.state.config.EVALUATION_ARENA_MODELS + ] + else: + # Add default arena model + arena_models = [ + { + "id": DEFAULT_ARENA_MODEL["id"], + "name": DEFAULT_ARENA_MODEL["name"], + "info": { + "meta": DEFAULT_ARENA_MODEL["meta"], + }, + "object": "model", + "created": int(time.time()), + "owned_by": "arena", + "arena": True, + } + ] + models = models + arena_models + + global_action_ids = [ + function.id for function in Functions.get_global_action_functions() + ] + enabled_action_ids = [ + function.id + for function in Functions.get_functions_by_type("action", active_only=True) + ] + + custom_models = Models.get_all_models() + for custom_model in custom_models: + if custom_model.base_model_id is None: + for model in models: + if ( + custom_model.id == model["id"] + or custom_model.id == model["id"].split(":")[0] + ): + if custom_model.is_active: + model["name"] = custom_model.name + model["info"] = custom_model.model_dump() + + action_ids = [] + if "info" in model and "meta" in model["info"]: + action_ids.extend( + 
model["info"]["meta"].get("actionIds", []) + ) + + model["action_ids"] = action_ids + else: + models.remove(model) + + elif custom_model.is_active and ( + custom_model.id not in [model["id"] for model in models] + ): + owned_by = "openai" + pipe = None + action_ids = [] + + for model in models: + if ( + custom_model.base_model_id == model["id"] + or custom_model.base_model_id == model["id"].split(":")[0] + ): + owned_by = model["owned_by"] + if "pipe" in model: + pipe = model["pipe"] + break + + if custom_model.meta: + meta = custom_model.meta.model_dump() + if "actionIds" in meta: + action_ids.extend(meta["actionIds"]) + + models.append( + { + "id": f"{custom_model.id}", + "name": custom_model.name, + "object": "model", + "created": custom_model.created_at, + "owned_by": owned_by, + "info": custom_model.model_dump(), + "preset": True, + **({"pipe": pipe} if pipe is not None else {}), + "action_ids": action_ids, + } + ) + + # Process action_ids to get the actions + def get_action_items_from_module(function, module): + actions = [] + if hasattr(module, "actions"): + actions = module.actions + return [ + { + "id": f"{function.id}.{action['id']}", + "name": action.get("name", f"{function.name} ({action['id']})"), + "description": function.meta.description, + "icon_url": action.get( + "icon_url", function.meta.manifest.get("icon_url", None) + ), + } + for action in actions + ] + else: + return [ + { + "id": function.id, + "name": function.name, + "description": function.meta.description, + "icon_url": function.meta.manifest.get("icon_url", None), + } + ] + + def get_function_module_by_id(function_id): + if function_id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[function_id] + else: + function_module, _, _ = load_function_module_by_id(function_id) + request.app.state.FUNCTIONS[function_id] = function_module + + for model in models: + action_ids = [ + action_id + for action_id in list(set(model.pop("action_ids", []) + global_action_ids)) + if action_id in enabled_action_ids + ] + + model["actions"] = [] + for action_id in action_ids: + action_function = Functions.get_function_by_id(action_id) + if action_function is None: + raise Exception(f"Action not found: {action_id}") + + function_module = get_function_module_by_id(action_id) + model["actions"].extend( + get_action_items_from_module(action_function, function_module) + ) + log.debug(f"get_all_models() returned {len(models)} models") + + request.app.state.MODELS = {model["id"]: model for model in models} + return models + + +def check_model_access(user, model): + if model.get("arena"): + if not has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + raise Exception("Model not found") + else: + model_info = Models.get_model_by_id(model.get("id")) + if not model_info: + raise Exception("Model not found") + elif not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise Exception("Model not found") diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 722b1ea73..f0ab7a345 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -12,8 +12,8 @@ from fastapi import ( ) from starlette.responses import RedirectResponse -from open_webui.apps.webui.models.auths import Auths -from open_webui.apps.webui.models.users import Users +from open_webui.models.auths import Auths +from open_webui.models.users 
import Users from open_webui.config import ( DEFAULT_USER_ROLE, ENABLE_OAUTH_SIGNUP, @@ -26,6 +26,7 @@ from open_webui.config import ( OAUTH_USERNAME_CLAIM, OAUTH_ALLOWED_ROLES, OAUTH_ADMIN_ROLES, + OAUTH_ALLOWED_DOMAINS, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig, @@ -33,7 +34,7 @@ from open_webui.config import ( from open_webui.constants import ERROR_MESSAGES from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE from open_webui.utils.misc import parse_duration -from open_webui.utils.utils import get_password_hash, create_token +from open_webui.utils.auth import get_password_hash, create_token from open_webui.utils.webhook import post_webhook log = logging.getLogger(__name__) @@ -49,6 +50,7 @@ auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES +auth_manager_config.OAUTH_ALLOWED_DOMAINS = OAUTH_ALLOWED_DOMAINS auth_manager_config.WEBHOOK_URL = WEBHOOK_URL auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN @@ -156,6 +158,14 @@ class OAuthManager: if not email: log.warning(f"OAuth callback failed, email is missing: {user_data}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + if ( + "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS + and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS + ): + log.warning( + f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}" + ) + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) # Check if the user exists user = Users.get_user_by_oauth_sub(provider_sub) @@ -253,9 +263,18 @@ class OAuthManager: secure=WEBUI_SESSION_COOKIE_SECURE, ) + if ENABLE_OAUTH_SIGNUP.value: + oauth_id_token = token.get("id_token") + response.set_cookie( + key="oauth_id_token", + value=oauth_id_token, + httponly=True, + samesite=WEBUI_SESSION_COOKIE_SAME_SITE, + secure=WEBUI_SESSION_COOKIE_SECURE, + ) # Redirect back to the frontend with the JWT token redirect_url = f"{request.base_url}auth#token={jwt_token}" - return RedirectResponse(url=redirect_url) + return RedirectResponse(url=redirect_url, headers=response.headers) oauth_manager = OAuthManager() diff --git a/backend/open_webui/utils/pdf_generator.py b/backend/open_webui/utils/pdf_generator.py index fb6cd57d5..bbaf42dbb 100644 --- a/backend/open_webui/utils/pdf_generator.py +++ b/backend/open_webui/utils/pdf_generator.py @@ -9,7 +9,7 @@ import site from fpdf import FPDF from open_webui.env import STATIC_DIR, FONTS_DIR -from open_webui.apps.webui.models.chats import ChatTitleMessagesForm +from open_webui.models.chats import ChatTitleMessagesForm class PDFGenerator: @@ -51,21 +51,25 @@ class PDFGenerator: # extends pymdownx extension to convert markdown to html. # - https://facelessuser.github.io/pymdown-extensions/usage_notes/ - html_content = markdown(content, extensions=["pymdownx.extra"]) + # html_content = markdown(content, extensions=["pymdownx.extra"]) html_message = f""" -
        [message template markup lost in extraction: the diff restructures the per-message HTML around {date_str}, {role.title()}, {model}, and {content}, adding <br/> spacing before the message body]
        """
        return html_message
@@ -74,18 +78,15 @@
        return f"""
        [document template markup lost in extraction: the diff simplifies the page wrapper to the {self.form_data.title} heading followed directly by {self.messages_html}]
""" @@ -114,9 +115,12 @@ class PDFGenerator: pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf") pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf") pdf.add_font("NotoSansSC", "", f"{FONTS_DIR}/NotoSansSC-Regular.ttf") + pdf.add_font("Twemoji", "", f"{FONTS_DIR}/Twemoji.ttf") pdf.set_font("NotoSans", size=12) - pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP", "NotoSansSC"]) + pdf.set_fallback_fonts( + ["NotoSansKR", "NotoSansJP", "NotoSansSC", "Twemoji"] + ) pdf.set_auto_page_break(auto=True, margin=15) diff --git a/backend/open_webui/apps/webui/utils.py b/backend/open_webui/utils/plugin.py similarity index 98% rename from backend/open_webui/apps/webui/utils.py rename to backend/open_webui/utils/plugin.py index 054158b3e..17b86cea1 100644 --- a/backend/open_webui/apps/webui/utils.py +++ b/backend/open_webui/utils/plugin.py @@ -8,8 +8,8 @@ import tempfile import logging from open_webui.env import SRC_LOG_LEVELS -from open_webui.apps.webui.models.functions import Functions -from open_webui.apps.webui.models.tools import Tools +from open_webui.models.functions import Functions +from open_webui.models.tools import Tools log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py index b8501e92c..891016e43 100644 --- a/backend/open_webui/utils/response.py +++ b/backend/open_webui/utils/response.py @@ -21,8 +21,63 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response) message_content = data.get("message", {}).get("content", "") done = data.get("done", False) + usage = None + if done: + usage = { + "response_token/s": ( + round( + ( + ( + data.get("eval_count", 0) + / ((data.get("eval_duration", 0) / 1_000_000_000)) + ) + * 100 + ), + 2, + ) + if data.get("eval_duration", 0) > 0 + else "N/A" + ), + "prompt_token/s": ( + round( + ( + ( + data.get("prompt_eval_count", 0) + / ( + ( + data.get("prompt_eval_duration", 0) + / 1_000_000_000 + ) + ) + ) + * 100 + ), + 2, + ) + if data.get("prompt_eval_duration", 0) > 0 + else "N/A" + ), + "total_duration": round( + ((data.get("total_duration", 0) / 1_000_000) * 100), 2 + ), + "load_duration": round( + ((data.get("load_duration", 0) / 1_000_000) * 100), 2 + ), + "prompt_eval_count": data.get("prompt_eval_count", 0), + "prompt_eval_duration": round( + ((data.get("prompt_eval_duration", 0) / 1_000_000) * 100), 2 + ), + "eval_count": data.get("eval_count", 0), + "eval_duration": round( + ((data.get("eval_duration", 0) / 1_000_000) * 100), 2 + ), + "approximate_total": ( + lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s" + )((data.get("total_duration", 0) or 0) // 1_000_000_000), + } + data = openai_chat_chunk_message_template( - model, message_content if not done else None + model, message_content if not done else None, usage ) line = f"data: {json.dumps(data)}\n\n" diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 604161a31..ebb7483ba 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -16,6 +16,22 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) +def get_task_model_id( + default_model_id: str, task_model: str, task_model_external: str, models +) -> str: + # Set the task model + task_model_id = default_model_id + # Check if the user has a custom task model and use that model + if models[task_model_id]["owned_by"] == "ollama": + if task_model and task_model in models: + task_model_id = 
task_model + else: + if task_model_external and task_model_external in models: + task_model_id = task_model_external + + return task_model_id + + def prompt_template( template: str, user_name: Optional[str] = None, user_location: Optional[str] = None ) -> str: diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 60a9f942f..b6e13011d 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -4,11 +4,15 @@ import re from typing import Any, Awaitable, Callable, get_type_hints from functools import update_wrapper, partial -from langchain_core.utils.function_calling import convert_to_openai_function -from open_webui.apps.webui.models.tools import Tools -from open_webui.apps.webui.models.users import UserModel -from open_webui.apps.webui.utils import load_tools_module_by_id + +from fastapi import Request from pydantic import BaseModel, Field, create_model +from langchain_core.utils.function_calling import convert_to_openai_function + + +from open_webui.models.tools import Tools +from open_webui.models.users import UserModel +from open_webui.utils.plugin import load_tools_module_by_id log = logging.getLogger(__name__) @@ -32,7 +36,7 @@ def apply_extra_params_to_tool_function( # Mutation on extra_params def get_tools( - webui_app, tool_ids: list[str], user: UserModel, extra_params: dict + request: Request, tool_ids: list[str], user: UserModel, extra_params: dict ) -> dict[str, dict]: tools_dict = {} @@ -41,10 +45,10 @@ def get_tools( if tools is None: continue - module = webui_app.state.TOOLS.get(tool_id, None) + module = request.app.state.TOOLS.get(tool_id, None) if module is None: module, _ = load_tools_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = module + request.app.state.TOOLS[tool_id] = module extra_params["__id__"] = tool_id if hasattr(module, "valves") and hasattr(module, "Valves"): diff --git a/package-lock.json b/package-lock.json index 020cd0f53..16542ed99 100644 --- a/package-lock.json +++ b/package-lock.json @@ -27,6 +27,7 @@ "async": "^3.2.5", "bits-ui": "^0.19.7", "codemirror": "^6.0.1", + "codemirror-lang-hcl": "^0.0.0-beta.2", "crc-32": "^1.2.2", "dayjs": "^1.11.10", "dompurify": "^3.1.6", @@ -4267,6 +4268,17 @@ "@codemirror/view": "^6.0.0" } }, + "node_modules/codemirror-lang-hcl": { + "version": "0.0.0-beta.2", + "resolved": "https://registry.npmjs.org/codemirror-lang-hcl/-/codemirror-lang-hcl-0.0.0-beta.2.tgz", + "integrity": "sha512-R3ew7Z2EYTdHTMXsWKBW9zxnLoLPYO+CrAa3dPZjXLrIR96Q3GR4cwJKF7zkSsujsnWgwRQZonyWpXYXfhQYuQ==", + "license": "MIT", + "dependencies": { + "@codemirror/language": "^6.0.0", + "@lezer/highlight": "^1.0.0", + "@lezer/lr": "^1.0.0" + } + }, "node_modules/coincident": { "version": "1.2.3", "resolved": "https://registry.npmjs.org/coincident/-/coincident-1.2.3.tgz", diff --git a/package.json b/package.json index c131e1f91..3b5911791 100644 --- a/package.json +++ b/package.json @@ -50,6 +50,7 @@ "type": "module", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", + "codemirror-lang-hcl": "^0.0.0-beta.2", "@codemirror/lang-python": "^6.1.6", "@codemirror/language-data": "^6.5.1", "@codemirror/theme-one-dark": "^6.1.2", diff --git a/pyproject.toml b/pyproject.toml index 0554baa9e..de14a9fa1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,13 +105,14 @@ dependencies = [ "ldap3==2.9.1" ] readme = "README.md" -requires-python = ">= 3.11, < 3.12.0a1" +requires-python = ">= 3.11, < 3.13.0a1" dynamic = ["version"] classifiers = [ "Development Status :: 4 - Beta", "License :: 
OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Communications :: Chat", "Topic :: Multimedia", ] diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index e76aa3c99..d06fbf3d7 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -110,7 +110,7 @@ export const chatAction = async (token: string, action_id: string, body: ChatAct export const getTaskConfig = async (token: string = '') => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/config`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/config`, { method: 'GET', headers: { Accept: 'application/json', @@ -138,7 +138,7 @@ export const getTaskConfig = async (token: string = '') => { export const updateTaskConfig = async (token: string, config: object) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/config/update`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/config/update`, { method: 'POST', headers: { Accept: 'application/json', @@ -176,7 +176,7 @@ export const generateTitle = async ( ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/title/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/title/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -216,7 +216,7 @@ export const generateTags = async ( ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/tags/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/tags/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -288,7 +288,7 @@ export const generateEmoji = async ( ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/emoji/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/emoji/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -337,7 +337,7 @@ export const generateQueries = async ( ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/queries/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/queries/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -407,7 +407,7 @@ export const generateAutoCompletion = async ( const controller = new AbortController(); let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/auto/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/auto/completions`, { signal: controller.signal, method: 'POST', headers: { @@ -477,7 +477,7 @@ export const generateMoACompletion = async ( const controller = new AbortController(); let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/moa/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/moa/completions`, { signal: controller.signal, method: 'POST', headers: { @@ -507,7 +507,7 @@ export const generateMoACompletion = async ( export const getPipelinesList = async (token: string = '') => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/list`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/list`, { method: 'GET', headers: { Accept: 'application/json', @@ -541,7 +541,7 @@ export const uploadPipeline = async (token: string, file: File, urlIdx: string) formData.append('file', file); formData.append('urlIdx', urlIdx); - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/upload`, { + const res = await 
fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/upload`, { method: 'POST', headers: { ...(token && { authorization: `Bearer ${token}` }) @@ -573,7 +573,7 @@ export const uploadPipeline = async (token: string, file: File, urlIdx: string) export const downloadPipeline = async (token: string, url: string, urlIdx: string) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/add`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/add`, { method: 'POST', headers: { Accept: 'application/json', @@ -609,7 +609,7 @@ export const downloadPipeline = async (token: string, url: string, urlIdx: strin export const deletePipeline = async (token: string, id: string, urlIdx: string) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/delete`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/delete`, { method: 'DELETE', headers: { Accept: 'application/json', @@ -650,7 +650,7 @@ export const getPipelines = async (token: string, urlIdx?: string) => { searchParams.append('urlIdx', urlIdx); } - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines?${searchParams.toString()}`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines?${searchParams.toString()}`, { method: 'GET', headers: { Accept: 'application/json', @@ -685,7 +685,7 @@ export const getPipelineValves = async (token: string, pipeline_id: string, urlI } const res = await fetch( - `${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves?${searchParams.toString()}`, + `${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves?${searchParams.toString()}`, { method: 'GET', headers: { @@ -721,7 +721,7 @@ export const getPipelineValvesSpec = async (token: string, pipeline_id: string, } const res = await fetch( - `${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves/spec?${searchParams.toString()}`, + `${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves/spec?${searchParams.toString()}`, { method: 'GET', headers: { @@ -762,7 +762,7 @@ export const updatePipelineValves = async ( } const res = await fetch( - `${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves/update?${searchParams.toString()}`, + `${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves/update?${searchParams.toString()}`, { method: 'POST', headers: { diff --git a/src/lib/apis/streaming/index.ts b/src/lib/apis/streaming/index.ts index 54804385d..5617ce36c 100644 --- a/src/lib/apis/streaming/index.ts +++ b/src/lib/apis/streaming/index.ts @@ -77,10 +77,14 @@ async function* openAIStreamToIterator( continue; } + if (parsedData.usage) { + yield { done: false, value: '', usage: parsedData.usage }; + continue; + } + yield { done: false, value: parsedData.choices?.[0]?.delta?.content ?? 
'',
-					usage: parsedData.usage
				};
			} catch (e) {
				console.error('Error extracting delta from SSE event:', e);
@@ -98,10 +102,26 @@ async function* streamLargeDeltasAsRandomChunks(
			yield textStreamUpdate;
			return;
		}
+
+		if (textStreamUpdate.error) {
+			yield textStreamUpdate;
+			continue;
+		}
		if (textStreamUpdate.sources) {
			yield textStreamUpdate;
			continue;
		}
+		if (textStreamUpdate.selectedModelId) {
+			yield textStreamUpdate;
+			continue;
+		}
+		if (textStreamUpdate.usage) {
+			yield textStreamUpdate;
+			continue;
+		}
+
+
		let content = textStreamUpdate.value;
		if (content.length < 5) {
			yield { done: false, value: content };
diff --git a/src/lib/components/admin/Settings/Images.svelte b/src/lib/components/admin/Settings/Images.svelte
index c76e192bf..b0492f24b 100644
--- a/src/lib/components/admin/Settings/Images.svelte
+++ b/src/lib/components/admin/Settings/Images.svelte
@@ -105,10 +105,11 @@
	};

	const updateConfigHandler = async () => {
-		const res = await updateConfig(localStorage.token, config).catch((error) => {
-			toast.error(error);
-			return null;
-		});
+		const res = await updateConfig(localStorage.token, config)
+			.catch((error) => {
+				toast.error(error);
+				return null;
+			});

		if (res) {
			config = res;
diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte
index 7a3361682..f084de65a 100644
--- a/src/lib/components/admin/Settings/Models.svelte
+++ b/src/lib/components/admin/Settings/Models.svelte
@@ -137,7 +137,7 @@
	});

-
+
{#if models !== null}
	{#if selectedModelId === null}
diff --git a/src/lib/components/admin/Settings/Models/ConfigureModelsModal.svelte b/src/lib/components/admin/Settings/Models/ConfigureModelsModal.svelte
index 4922b5b6f..23865c184 100644
--- a/src/lib/components/admin/Settings/Models/ConfigureModelsModal.svelte
+++ b/src/lib/components/admin/Settings/Models/ConfigureModelsModal.svelte
@@ -18,7 +18,7 @@
	import Plus from '$lib/components/icons/Plus.svelte';

	export let show = false;
-	export let init = () => {};
+	export let initHandler = () => {};

	let config = null;

@@ -29,26 +29,11 @@
	let loading = false;
	let showResetModal = false;

-	const submitHandler = async () => {
-		loading = true;
+	$: if (show) {
+		init();
+	}

-		const res = await setModelsConfig(localStorage.token, {
-			DEFAULT_MODELS: defaultModelIds.join(','),
-			MODEL_ORDER_LIST: modelIds
-		});
-
-		if (res) {
-			toast.success($i18n.t('Models configuration saved successfully'));
-			init();
-			show = false;
-		} else {
-			toast.error($i18n.t('Failed to save models configuration'));
-		}
-
-		loading = false;
-	};
-
-	onMount(async () => {
+	const init = async () => {
		config = await getModelsConfig(localStorage.token);

		if (config?.DEFAULT_MODELS) {
@@ -68,6 +53,28 @@
			// Add remaining IDs not in MODEL_ORDER_LIST, sorted alphabetically
			...allModelIds.filter((id) => !orderedSet.has(id)).sort((a, b) => a.localeCompare(b))
		];
+	};
+	const submitHandler = async () => {
+		loading = true;
+
+		const res = await setModelsConfig(localStorage.token, {
+			DEFAULT_MODELS: defaultModelIds.join(','),
+			MODEL_ORDER_LIST: modelIds
+		});
+
+		if (res) {
+			toast.success($i18n.t('Models configuration saved successfully'));
+			initHandler();
+			show = false;
+		} else {
+			toast.error($i18n.t('Failed to save models configuration'));
+		}
+
+		loading = false;
+	};
+
+	onMount(async () => {
+		init();
	});

@@ -79,7 +86,7 @@
			const res = deleteAllModels(localStorage.token);
			if (res) {
				toast.success($i18n.t('All models deleted successfully'));
- init(); + initHandler(); } }} /> diff --git a/src/lib/components/admin/Settings/WebSearch.svelte b/src/lib/components/admin/Settings/WebSearch.svelte index a3ccbec1d..58eb09da3 100644 --- a/src/lib/components/admin/Settings/WebSearch.svelte +++ b/src/lib/components/admin/Settings/WebSearch.svelte @@ -16,6 +16,7 @@ 'searxng', 'google_pse', 'brave', + 'kagi', 'mojeek', 'serpstack', 'serper', @@ -155,6 +156,17 @@ bind:value={webConfig.search.brave_search_api_key} />
+					{:else if webConfig.search.engine === 'kagi'}
+						[input markup lost in extraction: the added Kagi branch mirrors the Brave block above, a {$i18n.t('Kagi Search API Key')} label and a sensitive input presumably bound to webConfig.search.kagi_search_api_key]
{:else if webConfig.search.engine === 'mojeek'}
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index e6a653420..a55cbc87b 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -455,41 +455,43 @@ ////////////////////////// const initNewChat = async () => { - if (sessionStorage.selectedModels) { - selectedModels = JSON.parse(sessionStorage.selectedModels); - sessionStorage.removeItem('selectedModels'); - } else { - if ($page.url.searchParams.get('models')) { - selectedModels = $page.url.searchParams.get('models')?.split(','); - } else if ($page.url.searchParams.get('model')) { - const urlModels = $page.url.searchParams.get('model')?.split(','); + if ($page.url.searchParams.get('models')) { + selectedModels = $page.url.searchParams.get('models')?.split(','); + } else if ($page.url.searchParams.get('model')) { + const urlModels = $page.url.searchParams.get('model')?.split(','); - if (urlModels.length === 1) { - const m = $models.find((m) => m.id === urlModels[0]); - if (!m) { - const modelSelectorButton = document.getElementById('model-selector-0-button'); - if (modelSelectorButton) { - modelSelectorButton.click(); - await tick(); + if (urlModels.length === 1) { + const m = $models.find((m) => m.id === urlModels[0]); + if (!m) { + const modelSelectorButton = document.getElementById('model-selector-0-button'); + if (modelSelectorButton) { + modelSelectorButton.click(); + await tick(); - const modelSelectorInput = document.getElementById('model-search-input'); - if (modelSelectorInput) { - modelSelectorInput.focus(); - modelSelectorInput.value = urlModels[0]; - modelSelectorInput.dispatchEvent(new Event('input')); - } + const modelSelectorInput = document.getElementById('model-search-input'); + if (modelSelectorInput) { + modelSelectorInput.focus(); + modelSelectorInput.value = urlModels[0]; + modelSelectorInput.dispatchEvent(new Event('input')); } - } else { - selectedModels = urlModels; } } else { selectedModels = urlModels; } - } else if ($settings?.models) { - selectedModels = $settings?.models; - } else if ($config?.default_models) { - console.log($config?.default_models.split(',') ?? ''); - selectedModels = $config?.default_models.split(','); + } else { + selectedModels = urlModels; + } + } else { + if (sessionStorage.selectedModels) { + selectedModels = JSON.parse(sessionStorage.selectedModels); + sessionStorage.removeItem('selectedModels'); + } else { + if ($settings?.models) { + selectedModels = $settings?.models; + } else if ($config?.default_models) { + console.log($config?.default_models.split(',') ?? ''); + selectedModels = $config?.default_models.split(','); + } } } @@ -1056,11 +1058,14 @@ } let _response = null; - if (model?.owned_by === 'ollama') { - _response = await sendPromptOllama(model, prompt, responseMessageId, _chatId); - } else if (model) { - _response = await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); - } + + // if (model?.owned_by === 'ollama') { + // _response = await sendPromptOllama(model, prompt, responseMessageId, _chatId); + // } else if (model) { + // } + + _response = await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); + _responses.push(_response); if (chatEventEmitter) clearInterval(chatEventEmitter); @@ -1207,24 +1212,14 @@ $settings?.params?.stream_response ?? params?.stream_response ?? true; + const [res, controller] = await generateChatCompletion(localStorage.token, { stream: stream, model: model.id, messages: messagesBody, - options: { - ...{ ...($settings?.params ?? 
{}), ...params }, - stop: - (params?.stop ?? $settings?.params?.stop ?? undefined) - ? (params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop).map( - (str) => decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"')) - ) - : undefined, - num_predict: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined, - repeat_penalty: - params?.frequency_penalty ?? $settings?.params?.frequency_penalty ?? undefined - }, format: $settings.requestFormat ?? undefined, keep_alive: $settings.keepAlive ?? undefined, + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, session_id: $socket?.id, @@ -1542,13 +1537,6 @@ { stream: stream, model: model.id, - ...(stream && (model.info?.meta?.capabilities?.usage ?? false) - ? { - stream_options: { - include_usage: true - } - } - : {}), messages: [ params?.system || $settings.system || (responseMessage?.userContext ?? null) ? { @@ -1593,23 +1581,36 @@ content: message?.merged?.content ?? message.content }) })), - seed: params?.seed ?? $settings?.params?.seed ?? undefined, - stop: - (params?.stop ?? $settings?.params?.stop ?? undefined) - ? (params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop).map( - (str) => decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"')) - ) - : undefined, - temperature: params?.temperature ?? $settings?.params?.temperature ?? undefined, - top_p: params?.top_p ?? $settings?.params?.top_p ?? undefined, - frequency_penalty: - params?.frequency_penalty ?? $settings?.params?.frequency_penalty ?? undefined, - max_tokens: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined, + + // params: { + // ...$settings?.params, + // ...params, + + // format: $settings.requestFormat ?? undefined, + // keep_alive: $settings.keepAlive ?? undefined, + // stop: + // (params?.stop ?? $settings?.params?.stop ?? undefined) + // ? ( + // params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop + // ).map((str) => + // decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"')) + // ) + // : undefined + // }, + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, session_id: $socket?.id, chat_id: $chatId, - id: responseMessageId + id: responseMessageId, + + ...(stream && (model.info?.meta?.capabilities?.usage ?? false) + ? { + stream_options: { + include_usage: true + } + } + : {}) }, `${WEBUI_BASE_URL}/api` ); @@ -1636,6 +1637,7 @@ await handleOpenAIError(error, null, model, responseMessage); break; } + if (done || stopResponseFlag || _chatId !== $chatId) { responseMessage.done = true; history.messages[responseMessageId] = responseMessage; @@ -1648,7 +1650,7 @@ } if (usage) { - responseMessage.info = { ...usage, openai: true, usage }; + responseMessage.usage = usage; } if (selectedModelId) { diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 296cc7939..800059055 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -781,7 +781,7 @@