diff --git a/CHANGELOG.md b/CHANGELOG.md index 19686d028..46c1c644e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,34 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.3.22] - 2024-09-19 + +### Added + +- **⭐ Chat Overview**: Introducing a node-based interactive message diagram for improved visualization of conversation flows. +- **🔗 Multiple Vector DB Support**: Now supports multiple vector databases, including newly added support for Milvus. Community contributions for additional database support are highly encouraged! +- **📡 Experimental Non-Stream Chat Completion**: Experimental support for OpenAI o1 models, which do not support streaming, enabling more versatile model deployment (see the sketch after this list). +- **🔍 Experimental Colbert-AI Reranker Integration**: Added support for "jinaai/jina-colbert-v2" as a reranker, enhancing search relevance and accuracy. Note: it may not function at all on low-spec computers. +- **🕸️ ENABLE_WEBSOCKET_SUPPORT**: Added an environment variable that lets an instance reject websocket upgrades, stabilizing connections on platforms with websocket issues. +- **🔊 Azure Speech Service Integration**: Added support for Azure Speech services for Text-to-Speech (TTS). +- **🎚️ Customizable Playback Speed**: Playback speed control is now available in Call mode settings, allowing users to adjust audio playback speed to their preferences. +- **🧠 Enhanced Error Messaging**: The system now displays helpful error messages directly to users when chat completion fails. +- **📂 Save Model as Transparent PNG**: Model profile images are now saved as PNGs, supporting transparency and improving visual integration. +- **📱 iPhone Compatibility Adjustments**: Added padding to accommodate the iPhone navigation bar, improving UI display on these devices. +- **🔗 Secure Response Headers**: Implemented security response headers, bolstering web application security. +- **🔧 Enhanced AUTOMATIC1111 Settings**: Users can now configure 'CFG Scale', 'Sampler', and 'Scheduler' parameters directly in the admin settings, enhancing workflow flexibility without source code modifications. +- **🌍 i18n Updates**: Enhanced translations for Chinese, Ukrainian, Russian, and French, fostering a better localized experience. + +### Fixed + +- **🛠️ Chat Message Deletion**: Resolved issues with chat message deletion, ensuring smoother user interaction and system stability. +- **🔢 Ordered List Numbering**: Fixed incorrect numbering in ordered lists. + +### Changed + +- **🎨 Transparent Icon Handling**: Model icons can now be displayed on transparent backgrounds, improving UI aesthetics. +- **📝 Improved RAG Template**: Enhanced the Retrieval-Augmented Generation template, optimizing context handling and error checking for more precise operation.
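To make the new non-stream path concrete, here is a minimal client-side sketch. The port, endpoint path, model id, and API key are illustrative assumptions about a local deployment, not guarantees from this release:

```python
import requests

# Hedged sketch: o1-style models reject streaming, so the client sends one
# blocking request with "stream": False and reads the full response at once.
resp = requests.post(
    "http://localhost:3000/api/chat/completions",  # assumed local Open WebUI instance
    headers={"Authorization": "Bearer YOUR_API_KEY"},  # placeholder key
    json={
        "model": "o1-preview",  # hypothetical non-streaming model id
        "stream": False,
        "messages": [{"role": "user", "content": "Hello"}],
    },
    timeout=120,
)
print(resp.json()["choices"][0]["message"]["content"])
```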
+ ## [0.3.21] - 2024-09-08 ### Added diff --git a/Dockerfile b/Dockerfile index 8078bf0ea..c944f54e6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -74,6 +74,10 @@ ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \ ## Hugging Face download cache ## ENV HF_HOME="/app/backend/data/cache/embedding/models" + +## Torch Extensions ## +# ENV TORCH_EXTENSIONS_DIR="/.cache/torch_extensions" + #### Other models ########################################################## WORKDIR /app/backend @@ -96,7 +100,7 @@ RUN chown -R $UID:$GID /app $HOME RUN if [ "$USE_OLLAMA" = "true" ]; then \ apt-get update && \ # Install pandoc and netcat - apt-get install -y --no-install-recommends pandoc netcat-openbsd curl && \ + apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \ apt-get install -y --no-install-recommends gcc python3-dev && \ # for RAG OCR apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \ @@ -109,7 +113,7 @@ RUN if [ "$USE_OLLAMA" = "true" ]; then \ else \ apt-get update && \ # Install pandoc, netcat and gcc - apt-get install -y --no-install-recommends pandoc gcc netcat-openbsd curl jq && \ + apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \ apt-get install -y --no-install-recommends gcc python3-dev && \ # for RAG OCR apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \ @@ -157,5 +161,6 @@ USER $UID:$GID ARG BUILD_HASH ENV WEBUI_BUILD_VERSION=${BUILD_HASH} +ENV DOCKER true CMD [ "bash", "start.sh"] diff --git a/backend/open_webui/__init__.py b/backend/open_webui/__init__.py index 30e83b198..167f0fb60 100644 --- a/backend/open_webui/__init__.py +++ b/backend/open_webui/__init__.py @@ -39,6 +39,19 @@ def serve( "/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib", ] ) + try: + import torch + + assert torch.cuda.is_available(), "CUDA not available" + typer.echo("CUDA seems to be working") + except Exception as e: + typer.echo( + "Error when testing CUDA but USE_CUDA_DOCKER is true. 
" + "Resetting USE_CUDA_DOCKER to false and removing " + f"LD_LIBRARY_PATH modifications: {e}" + ) + os.environ["USE_CUDA_DOCKER"] = "false" + os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH) import open_webui.main # we need set environment variables before importing main uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*") diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py index 1fc44b28f..0eee533bd 100644 --- a/backend/open_webui/apps/audio/main.py +++ b/backend/open_webui/apps/audio/main.py @@ -19,16 +19,18 @@ from open_webui.config import ( AUDIO_TTS_OPENAI_API_KEY, AUDIO_TTS_SPLIT_ON, AUDIO_TTS_VOICE, + AUDIO_TTS_AZURE_SPEECH_REGION, + AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, CACHE_DIR, CORS_ALLOW_ORIGIN, - DEVICE_TYPE, WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE, WHISPER_MODEL_DIR, AppConfig, ) + from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS +from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse @@ -62,6 +64,9 @@ app.state.config.TTS_VOICE = AUDIO_TTS_VOICE app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON +app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION +app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT + # setting device type for whisper model whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" log.info(f"whisper_device_type: {whisper_device_type}") @@ -78,6 +83,8 @@ class TTSConfigForm(BaseModel): MODEL: str VOICE: str SPLIT_ON: str + AZURE_SPEECH_REGION: str + AZURE_SPEECH_OUTPUT_FORMAT: str class STTConfigForm(BaseModel): @@ -130,6 +137,8 @@ async def get_audio_config(user=Depends(get_admin_user)): "MODEL": app.state.config.TTS_MODEL, "VOICE": app.state.config.TTS_VOICE, "SPLIT_ON": app.state.config.TTS_SPLIT_ON, + "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION, + "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, }, "stt": { "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, @@ -151,6 +160,10 @@ async def update_audio_config( app.state.config.TTS_MODEL = form_data.tts.MODEL app.state.config.TTS_VOICE = form_data.tts.VOICE app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON + app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION + app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = ( + form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT + ) app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY @@ -166,6 +179,8 @@ async def update_audio_config( "MODEL": app.state.config.TTS_MODEL, "VOICE": app.state.config.TTS_VOICE, "SPLIT_ON": app.state.config.TTS_SPLIT_ON, + "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION, + "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, }, "stt": { "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, @@ -301,6 +316,42 @@ async def speech(request: Request, user=Depends(get_verified_user)): detail=error_detail, ) + elif app.state.config.TTS_ENGINE == "azure": + payload = None + try: + payload = json.loads(body.decode("utf-8")) + except Exception as e: + log.exception(e) + raise HTTPException(status_code=400, 
detail="Invalid JSON payload") + + region = app.state.config.TTS_AZURE_SPEECH_REGION + language = app.state.config.TTS_VOICE + locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1]) + output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT + url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1" + + headers = { + "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY, + "Content-Type": "application/ssml+xml", + "X-Microsoft-OutputFormat": output_format, + } + + data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}"> + <voice name="{language}">{payload["input"]}</voice> + </speak>""" + + response = requests.post(url, headers=headers, data=data) + + if response.status_code == 200: + with open(file_path, "wb") as f: + f.write(response.content) + return FileResponse(file_path) + else: + log.error(f"Error synthesizing speech - {response.reason}") + raise HTTPException( + status_code=500, detail=f"Error synthesizing speech - {response.reason}" + ) + @app.post("/transcriptions") def transcribe( @@ -309,7 +360,7 @@ def transcribe( ): log.info(f"file.content_type: {file.content_type}") - if file.content_type not in ["audio/mpeg", "audio/wav"]: + if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED, @@ -443,7 +494,7 @@ def get_available_models() -> list[dict]: try: response = requests.get( - "https://api.elevenlabs.io/v1/models", headers=headers + "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5 ) response.raise_for_status() models = response.json() @@ -478,6 +529,21 @@ def get_available_voices() -> dict: except Exception: # Avoided @lru_cache with exception pass + elif app.state.config.TTS_ENGINE == "azure": + try: + region = app.state.config.TTS_AZURE_SPEECH_REGION + url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list" + headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY} + + response = requests.get(url, headers=headers) + response.raise_for_status() + voices = response.json() + for voice in voices: + ret[voice["ShortName"]] = ( + f"{voice['DisplayName']} ({voice['ShortName']})" + ) + except requests.RequestException as e: + log.error(f"Error fetching voices: {str(e)}") return ret diff --git a/backend/open_webui/apps/images/main.py b/backend/open_webui/apps/images/main.py index 17afd645c..1074e2cb0 100644 --- a/backend/open_webui/apps/images/main.py +++ b/backend/open_webui/apps/images/main.py @@ -17,6 +17,9 @@ from open_webui.apps.images.utils.comfyui import ( from open_webui.config import ( AUTOMATIC1111_API_AUTH, AUTOMATIC1111_BASE_URL, + AUTOMATIC1111_CFG_SCALE, + AUTOMATIC1111_SAMPLER, + AUTOMATIC1111_SCHEDULER, CACHE_DIR, COMFYUI_BASE_URL, COMFYUI_WORKFLOW, @@ -65,6 +68,9 @@ app.state.config.MODEL = IMAGE_GENERATION_MODEL app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH +app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE +app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER +app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES @@ -85,6 +91,9 @@ async def get_config(request: Request, user=Depends(get_admin_user)): "automatic1111": { "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, +
"AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE, + "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER, + "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER, }, "comfyui": { "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, @@ -102,6 +111,9 @@ class OpenAIConfigForm(BaseModel): class Automatic1111ConfigForm(BaseModel): AUTOMATIC1111_BASE_URL: str AUTOMATIC1111_API_AUTH: str + AUTOMATIC1111_CFG_SCALE: Optional[str] + AUTOMATIC1111_SAMPLER: Optional[str] + AUTOMATIC1111_SCHEDULER: Optional[str] class ComfyUIConfigForm(BaseModel): @@ -133,6 +145,22 @@ async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)): form_data.automatic1111.AUTOMATIC1111_API_AUTH ) + app.state.config.AUTOMATIC1111_CFG_SCALE = ( + float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE) + if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE != "" + else None + ) + app.state.config.AUTOMATIC1111_SAMPLER = ( + form_data.automatic1111.AUTOMATIC1111_SAMPLER + if form_data.automatic1111.AUTOMATIC1111_SAMPLER != "" + else None + ) + app.state.config.AUTOMATIC1111_SCHEDULER = ( + form_data.automatic1111.AUTOMATIC1111_SCHEDULER + if form_data.automatic1111.AUTOMATIC1111_SCHEDULER != "" + else None + ) + app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/") app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES @@ -147,6 +175,9 @@ async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)): "automatic1111": { "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, + "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE, + "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER, + "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER, }, "comfyui": { "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, @@ -524,6 +555,15 @@ async def image_generations( if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt + if app.state.config.AUTOMATIC1111_CFG_SCALE: + data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE + + if app.state.config.AUTOMATIC1111_SAMPLER: + data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER + + if app.state.config.AUTOMATIC1111_SCHEDULER: + data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER + # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index fe36010b7..6c639268e 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -545,6 +545,55 @@ class GenerateEmbeddingsForm(BaseModel): @app.post("/api/embed") @app.post("/api/embed/{url_idx}") +async def generate_embeddings( + form_data: GenerateEmbeddingsForm, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), +): + if url_idx is None: + model = form_data.model + + if ":" not in model: + model = f"{model}:latest" + + if model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = app.state.config.OLLAMA_BASE_URLS[url_idx] + log.info(f"url: {url}") + + r = requests.request( + method="POST", + url=f"{url}/api/embed", + headers={"Content-Type": 
"application/json"}, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + try: + r.raise_for_status() + + return r.json() + except Exception as e: + log.exception(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except Exception: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + @app.post("/api/embeddings") @app.post("/api/embeddings/{url_idx}") async def generate_embeddings( @@ -571,7 +620,7 @@ async def generate_embeddings( r = requests.request( method="POST", - url=f"{url}/api/embed", + url=f"{url}/api/embeddings", headers={"Content-Type": "application/json"}, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -767,7 +816,10 @@ async def generate_chat_completion( log.debug(payload) return await post_streaming_url( - f"{url}/api/chat", json.dumps(payload), content_type="application/x-ndjson" + f"{url}/api/chat", + json.dumps(payload), + stream=form_data.stream, + content_type="application/x-ndjson", ) diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index 23ea1cc8c..9a27c46a3 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -423,6 +423,7 @@ async def generate_chat_completion( r = None session = None streaming = False + response = None try: session = aiohttp.ClientSession( @@ -435,8 +436,6 @@ async def generate_chat_completion( headers=headers, ) - r.raise_for_status() - # Check if response is SSE if "text/event-stream" in r.headers.get("Content-Type", ""): streaming = True @@ -449,19 +448,23 @@ async def generate_chat_completion( ), ) else: - response_data = await r.json() - return response_data + try: + response = await r.json() + except Exception as e: + log.error(e) + response = await r.text() + + r.raise_for_status() + return response except Exception as e: log.exception(e) error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = await r.json() - print(res) - if "error" in res: - error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" - except Exception: - error_detail = f"External: {e}" + if isinstance(response, dict): + if "error" in response: + error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}" + elif isinstance(response, str): + error_detail = response + raise HTTPException(status_code=r.status if r else 500, detail=error_detail) finally: if not streaming and session: diff --git a/backend/open_webui/apps/rag/main.py b/backend/open_webui/apps/rag/main.py index 6c064fe81..981a6fe5b 100644 --- a/backend/open_webui/apps/rag/main.py +++ b/backend/open_webui/apps/rag/main.py @@ -10,13 +10,21 @@ from datetime import datetime from pathlib import Path from typing import Iterator, Optional, Sequence, Union + +import numpy as np +import torch import requests import validators + +from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel + +from open_webui.apps.rag.search.main import SearchResult from open_webui.apps.rag.search.brave import search_brave from open_webui.apps.rag.search.duckduckgo import search_duckduckgo from open_webui.apps.rag.search.google_pse import search_google_pse from open_webui.apps.rag.search.jina_search import 
search_jina -from open_webui.apps.rag.search.main import SearchResult from open_webui.apps.rag.search.searchapi import search_searchapi from open_webui.apps.rag.search.searxng import search_searxng from open_webui.apps.rag.search.serper import search_serper @@ -33,15 +41,12 @@ from open_webui.apps.rag.utils import ( ) from open_webui.apps.webui.models.documents import DocumentForm, Documents from open_webui.apps.webui.models.files import Files -from chromadb.utils.batch_utils import create_batches from open_webui.config import ( BRAVE_SEARCH_API_KEY, - CHROMA_CLIENT, CHUNK_OVERLAP, CHUNK_SIZE, CONTENT_EXTRACTION_ENGINE, CORS_ALLOW_ORIGIN, - DEVICE_TYPE, DOCS_DIR, ENABLE_RAG_HYBRID_SEARCH, ENABLE_RAG_LOCAL_WEB_FETCH, @@ -64,6 +69,7 @@ from open_webui.config import ( RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + DEFAULT_RAG_TEMPLATE, RAG_TEMPLATE, RAG_TOP_K, RAG_WEB_SEARCH_CONCURRENT_REQUESTS, @@ -84,9 +90,16 @@ from open_webui.config import ( AppConfig, ) from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS -from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status -from fastapi.middleware.cors import CORSMiddleware +from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER +from open_webui.utils.misc import ( + calculate_sha256, + calculate_sha256_string, + extract_folders_after_data_docs, + sanitize_filename, +) +from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT + from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import ( BSHTMLLoader, @@ -105,14 +118,8 @@ from langchain_community.document_loaders import ( YoutubeLoader, ) from langchain_core.documents import Document -from pydantic import BaseModel -from open_webui.utils.misc import ( - calculate_sha256, - calculate_sha256_string, - extract_folders_after_data_docs, - sanitize_filename, -) -from open_webui.utils.utils import get_admin_user, get_verified_user +from colbert.infra import ColBERTConfig +from colbert.modeling.checkpoint import Checkpoint log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -143,13 +150,11 @@ app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SI app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL app.state.config.RAG_TEMPLATE = RAG_TEMPLATE - app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES - app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE app.state.YOUTUBE_LOADER_TRANSLATION = None @@ -175,13 +180,13 @@ app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_ def update_embedding_model( embedding_model: str, - update_model: bool = False, + auto_update: bool = False, ): if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": import sentence_transformers app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( - get_model_path(embedding_model, update_model), + get_model_path(embedding_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, ) @@ -191,16 +196,108 @@ def update_embedding_model( def update_reranking_model( reranking_model: str, - update_model: bool = False, + auto_update: bool = False, ): if reranking_model: - import sentence_transformers + if 
any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]): - app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder( - get_model_path(reranking_model, update_model), - device=DEVICE_TYPE, - trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, - ) + class ColBERT: + def __init__(self, name) -> None: + print("ColBERT: Loading model", name) + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + if DOCKER: + # This is a workaround for the issue with the docker container + # where the torch extension is not loaded properly + # and the following error is thrown: + # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory + + lock_file = "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock" + if os.path.exists(lock_file): + os.remove(lock_file) + + self.ckpt = Checkpoint( + name, + colbert_config=ColBERTConfig(model_name=name), + ).to(self.device) + pass + + def calculate_similarity_scores( + self, query_embeddings, document_embeddings + ): + + query_embeddings = query_embeddings.to(self.device) + document_embeddings = document_embeddings.to(self.device) + + # Validate dimensions to ensure compatibility + if query_embeddings.dim() != 3: + raise ValueError( + f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}." + ) + if document_embeddings.dim() != 3: + raise ValueError( + f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}." + ) + if query_embeddings.size(0) not in [1, document_embeddings.size(0)]: + raise ValueError( + "There should be either one query or queries equal to the number of documents." + ) + + # Transpose the query embeddings to align for matrix multiplication + transposed_query_embeddings = query_embeddings.permute(0, 2, 1) + # Compute similarity scores using batch matrix multiplication + computed_scores = torch.matmul( + document_embeddings, transposed_query_embeddings + ) + # Apply max pooling to extract the highest semantic similarity across each document's sequence + maximum_scores = torch.max(computed_scores, dim=1).values + + # Sum up the maximum scores across features to get the overall document relevance scores + final_scores = maximum_scores.sum(dim=1) + + normalized_scores = torch.softmax(final_scores, dim=0) + + return normalized_scores.detach().cpu().numpy().astype(np.float32) + + def predict(self, sentences): + + query = sentences[0][0] + docs = [i[1] for i in sentences] + + # Embedding the documents + embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0] + # Embedding the queries + embedded_queries = self.ckpt.queryFromText([query], bsize=32) + embedded_query = embedded_queries[0] + + # Calculate retrieval scores for the query against all documents + scores = self.calculate_similarity_scores( + embedded_query.unsqueeze(0), embedded_docs + ) + + return scores + + try: + app.state.sentence_transformer_rf = ColBERT( + get_model_path(reranking_model, auto_update) + ) + except Exception as e: + log.error(f"ColBERT: {e}") + app.state.sentence_transformer_rf = None + app.state.config.ENABLE_RAG_HYBRID_SEARCH = False + else: + import sentence_transformers + + try: + app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder( + get_model_path(reranking_model, auto_update), + device=DEVICE_TYPE, + trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + ) + except: + log.error("CrossEncoder error") + app.state.sentence_transformer_rf = None + 
app.state.config.ENABLE_RAG_HYBRID_SEARCH = False else: app.state.sentence_transformer_rf = None @@ -593,7 +690,7 @@ async def update_query_settings( form_data: QuerySettingsForm, user=Depends(get_admin_user) ): app.state.config.RAG_TEMPLATE = ( - form_data.template if form_data.template else RAG_TEMPLATE + form_data.template if form_data.template != "" else DEFAULT_RAG_TEMPLATE ) app.state.config.TOP_K = form_data.k if form_data.k else 4 app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 @@ -998,14 +1095,11 @@ def store_docs_in_vector_db( try: if overwrite: - for collection in CHROMA_CLIENT.list_collections(): - if collection_name == collection.name: - log.info(f"deleting existing collection {collection_name}") - CHROMA_CLIENT.delete_collection(name=collection_name) + if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): + log.info(f"deleting existing collection {collection_name}") + VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) - collection = CHROMA_CLIENT.create_collection(name=collection_name) - - embedding_func = get_embedding_function( + embedding_function = get_embedding_function( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, @@ -1014,17 +1108,18 @@ def store_docs_in_vector_db( app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, ) - embedding_texts = list(map(lambda x: x.replace("\n", " "), texts)) - embeddings = embedding_func(embedding_texts) - - for batch in create_batches( - api=CHROMA_CLIENT, - ids=[str(uuid.uuid4()) for _ in texts], - metadatas=metadatas, - embeddings=embeddings, - documents=texts, - ): - collection.add(*batch) + VECTOR_DB_CLIENT.insert( + collection_name=collection_name, + items=[ + { + "id": str(uuid.uuid4()), + "text": text, + "vector": embedding_function(text.replace("\n", " ")), + "metadata": metadatas[idx], + } + for idx, text in enumerate(texts) + ], + ) return True except Exception as e: @@ -1158,7 +1253,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str): elif ( file_content_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - or file_ext in ["doc", "docx"] + or file_ext == "docx" ): loader = Docx2txtLoader(file_path) elif file_content_type in [ @@ -1396,7 +1491,7 @@ def scan_docs_dir(user=Depends(get_admin_user)): @app.post("/reset/db") def reset_vector_db(user=Depends(get_admin_user)): - CHROMA_CLIENT.reset() + VECTOR_DB_CLIENT.reset() @app.post("/reset/uploads") @@ -1437,7 +1532,7 @@ def reset(user=Depends(get_admin_user)) -> bool: log.error("Failed to delete %s. 
Reason: %s" % (file_path, e)) try: - CHROMA_CLIENT.reset() + VECTOR_DB_CLIENT.reset() except Exception as e: log.exception(e) diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py index 2bf8a02e4..73ccfad38 100644 --- a/backend/open_webui/apps/rag/utils.py +++ b/backend/open_webui/apps/rag/utils.py @@ -1,24 +1,68 @@ import logging import os +import uuid from typing import Optional, Union import requests -from open_webui.apps.ollama.main import ( - GenerateEmbeddingsForm, - generate_ollama_embeddings, -) -from open_webui.config import CHROMA_CLIENT -from open_webui.env import SRC_LOG_LEVELS + from huggingface_hub import snapshot_download from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document + + +from open_webui.apps.ollama.main import ( + GenerateEmbeddingsForm, + generate_ollama_embeddings, +) +from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message +from open_webui.env import SRC_LOG_LEVELS + + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) +from typing import Any + +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.retrievers import BaseRetriever + + +class VectorSearchRetriever(BaseRetriever): + collection_name: Any + embedding_function: Any + top_k: int + + def _get_relevant_documents( + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, + ) -> list[Document]: + result = VECTOR_DB_CLIENT.search( + collection_name=self.collection_name, + vectors=[self.embedding_function(query)], + limit=self.top_k, + ) + + ids = result.ids[0] + metadatas = result.metadatas[0] + documents = result.documents[0] + + results = [] + for idx in range(len(ids)): + results.append( + Document( + metadata=metadatas[idx], + page_content=documents[idx], + ) + ) + return results + + def query_doc( collection_name: str, query: str, @@ -26,17 +70,18 @@ def query_doc( k: int, ): try: - collection = CHROMA_CLIENT.get_collection(name=collection_name) - query_embeddings = embedding_function(query) - - result = collection.query( - query_embeddings=[query_embeddings], - n_results=k, + result = VECTOR_DB_CLIENT.search( + collection_name=collection_name, + vectors=[embedding_function(query)], + limit=k, ) + print("result", result) + log.info(f"query_doc:result {result}") return result except Exception as e: + print(e) raise e @@ -47,27 +92,25 @@ def query_doc_with_hybrid_search( k: int, reranking_function, r: float, -): +) -> dict: try: - collection = CHROMA_CLIENT.get_collection(name=collection_name) - documents = collection.get() # get all documents + result = VECTOR_DB_CLIENT.get(collection_name=collection_name) bm25_retriever = BM25Retriever.from_texts( - texts=documents.get("documents"), - metadatas=documents.get("metadatas"), + texts=result.documents[0], + metadatas=result.metadatas[0], ) bm25_retriever.k = k - chroma_retriever = ChromaRetriever( - collection=collection, + vector_search_retriever = VectorSearchRetriever( + collection_name=collection_name, embedding_function=embedding_function, - top_n=k, + top_k=k, ) ensemble_retriever = EnsembleRetriever( - retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5] + retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5] ) - compressor = RerankCompressor( embedding_function=embedding_function, top_n=k, @@ -92,7 +135,9 @@ def 
query_doc_with_hybrid_search( raise e -def merge_and_sort_query_results(query_results, k, reverse=False): +def merge_and_sort_query_results( + query_results: list[dict], k: int, reverse: bool = False +) -> list[dict]: # Initialize lists to store combined data combined_distances = [] combined_documents = [] @@ -138,7 +183,7 @@ def query_collection( query: str, embedding_function, k: int, -): +) -> dict: results = [] for collection_name in collection_names: if collection_name: @@ -149,9 +194,9 @@ def query_collection( k=k, embedding_function=embedding_function, ) - results.append(result) - except Exception: - pass + results.append(result.model_dump()) + except Exception as e: + log.exception(f"Error when querying the collection: {e}") else: pass @@ -165,8 +210,9 @@ def query_collection_with_hybrid_search( k: int, reranking_function, r: float, -): +) -> dict: results = [] + error = False for collection_name in collection_names: try: result = query_doc_with_hybrid_search( @@ -178,14 +224,39 @@ def query_collection_with_hybrid_search( r=r, ) results.append(result) - except Exception: - pass + except Exception as e: + log.exception( + "Error when querying the collection with " f"hybrid_search: {e}" + ) + error = True + + if error: + raise Exception( + "Hybrid search failed for all collections. Using Non hybrid search as fallback." + ) + return merge_and_sort_query_results(results, k=k, reverse=True) def rag_template(template: str, context: str, query: str): - template = template.replace("[context]", context) - template = template.replace("[query]", query) + count = template.count("[context]") + assert "[context]" in template, "RAG template does not contain '[context]'" + + if "<context>" in context and "</context>" in context: + log.debug( + "WARNING: Potential prompt injection attack: the RAG " + "context contains '<context>' and '</context>'. This might be " + "nothing, or the user might be trying to hack something." + ) + + if "[query]" in context: + query_placeholder = f"[query-{str(uuid.uuid4())}]" + template = template.replace("[query]", query_placeholder) + template = template.replace("[context]", context) + template = template.replace(query_placeholder, query) + else: + template = template.replace("[context]", context) + template = template.replace("[query]", query) return template @@ -262,19 +333,27 @@ def get_rag_context( continue try: + context = None if file["type"] == "text": context = file["content"] else: if hybrid_search: - context = query_collection_with_hybrid_search( - collection_names=collection_names, - query=query, - embedding_function=embedding_function, - k=k, - reranking_function=reranking_function, - r=r, - ) - else: + try: + context = query_collection_with_hybrid_search( + collection_names=collection_names, + query=query, + embedding_function=embedding_function, + k=k, + reranking_function=reranking_function, + r=r, + ) + except Exception as e: + log.debug( + "Error when using hybrid search, using" + " non hybrid search as fallback."
+ ) + + if (not hybrid_search) or (context is None): context = query_collection( collection_names=collection_names, query=query, @@ -283,7 +362,6 @@ def get_rag_context( ) except Exception as e: log.exception(e) - context = None if context: relevant_contexts.append({**context, "source": file}) @@ -391,51 +469,11 @@ def generate_openai_batch_embeddings( return None -from typing import Any - -from langchain_core.callbacks import CallbackManagerForRetrieverRun -from langchain_core.retrievers import BaseRetriever - - -class ChromaRetriever(BaseRetriever): - collection: Any - embedding_function: Any - top_n: int - - def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - ) -> list[Document]: - query_embeddings = self.embedding_function(query) - - results = self.collection.query( - query_embeddings=[query_embeddings], - n_results=self.top_n, - ) - - ids = results["ids"][0] - metadatas = results["metadatas"][0] - documents = results["documents"][0] - - results = [] - for idx in range(len(ids)): - results.append( - Document( - metadata=metadatas[idx], - page_content=documents[idx], - ) - ) - return results - - import operator from typing import Optional, Sequence from langchain_core.callbacks import Callbacks from langchain_core.documents import BaseDocumentCompressor, Document -from langchain_core.pydantic_v1 import Extra class RerankCompressor(BaseDocumentCompressor): @@ -445,7 +483,7 @@ class RerankCompressor(BaseDocumentCompressor): r_score: float class Config: - extra = Extra.forbid + extra = "forbid" arbitrary_types_allowed = True def compress_documents( diff --git a/backend/open_webui/apps/rag/vector/connector.py b/backend/open_webui/apps/rag/vector/connector.py new file mode 100644 index 000000000..073becdbe --- /dev/null +++ b/backend/open_webui/apps/rag/vector/connector.py @@ -0,0 +1,10 @@ +from open_webui.apps.rag.vector.dbs.chroma import ChromaClient +from open_webui.apps.rag.vector.dbs.milvus import MilvusClient + + +from open_webui.config import VECTOR_DB + +if VECTOR_DB == "milvus": + VECTOR_DB_CLIENT = MilvusClient() +else: + VECTOR_DB_CLIENT = ChromaClient() diff --git a/backend/open_webui/apps/rag/vector/dbs/chroma.py b/backend/open_webui/apps/rag/vector/dbs/chroma.py new file mode 100644 index 000000000..5f9420108 --- /dev/null +++ b/backend/open_webui/apps/rag/vector/dbs/chroma.py @@ -0,0 +1,122 @@ +import chromadb +from chromadb import Settings +from chromadb.utils.batch_utils import create_batches + +from typing import Optional + +from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult +from open_webui.config import ( + CHROMA_DATA_PATH, + CHROMA_HTTP_HOST, + CHROMA_HTTP_PORT, + CHROMA_HTTP_HEADERS, + CHROMA_HTTP_SSL, + CHROMA_TENANT, + CHROMA_DATABASE, +) + + +class ChromaClient: + def __init__(self): + if CHROMA_HTTP_HOST != "": + self.client = chromadb.HttpClient( + host=CHROMA_HTTP_HOST, + port=CHROMA_HTTP_PORT, + headers=CHROMA_HTTP_HEADERS, + ssl=CHROMA_HTTP_SSL, + tenant=CHROMA_TENANT, + database=CHROMA_DATABASE, + settings=Settings(allow_reset=True, anonymized_telemetry=False), + ) + else: + self.client = chromadb.PersistentClient( + path=CHROMA_DATA_PATH, + settings=Settings(allow_reset=True, anonymized_telemetry=False), + tenant=CHROMA_TENANT, + database=CHROMA_DATABASE, + ) + + def has_collection(self, collection_name: str) -> bool: + # Check if the collection exists based on the collection name. 
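As a standalone illustration of this existence check, the same logic can be exercised against a throwaway in-memory Chroma client; a minimal sketch, assuming a chromadb version where list_collections() returns collection objects with a .name attribute:

```python
import chromadb

client = chromadb.EphemeralClient()  # in-memory; the app uses PersistentClient or HttpClient
client.create_collection("demo")

def has_collection(name: str) -> bool:
    # Same check as ChromaClient.has_collection: match by name over all collections.
    return name in [collection.name for collection in client.list_collections()]

print(has_collection("demo"))     # True
print(has_collection("missing"))  # False
```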
+ collections = self.client.list_collections() + return collection_name in [collection.name for collection in collections] + + def delete_collection(self, collection_name: str): + # Delete the collection based on the collection name. + return self.client.delete_collection(name=collection_name) + + def search( + self, collection_name: str, vectors: list[list[float | int]], limit: int + ) -> Optional[SearchResult]: + # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. + collection = self.client.get_collection(name=collection_name) + if collection: + result = collection.query( + query_embeddings=vectors, + n_results=limit, + ) + + return SearchResult( + **{ + "ids": result["ids"], + "distances": result["distances"], + "documents": result["documents"], + "metadatas": result["metadatas"], + } + ) + return None + + def get(self, collection_name: str) -> Optional[GetResult]: + # Get all the items in the collection. + collection = self.client.get_collection(name=collection_name) + if collection: + result = collection.get() + return GetResult( + **{ + "ids": [result["ids"]], + "documents": [result["documents"]], + "metadatas": [result["metadatas"]], + } + ) + return None + + def insert(self, collection_name: str, items: list[VectorItem]): + # Insert the items into the collection, if the collection does not exist, it will be created. + collection = self.client.get_or_create_collection(name=collection_name) + + ids = [item["id"] for item in items] + documents = [item["text"] for item in items] + embeddings = [item["vector"] for item in items] + metadatas = [item["metadata"] for item in items] + + for batch in create_batches( + api=self.client, + documents=documents, + embeddings=embeddings, + ids=ids, + metadatas=metadatas, + ): + collection.add(*batch) + + def upsert(self, collection_name: str, items: list[VectorItem]): + # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. + collection = self.client.get_or_create_collection(name=collection_name) + + ids = [item["id"] for item in items] + documents = [item["text"] for item in items] + embeddings = [item["vector"] for item in items] + metadatas = [item["metadata"] for item in items] + + collection.upsert( + ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas + ) + + def delete(self, collection_name: str, ids: list[str]): + # Delete the items from the collection based on the ids. + collection = self.client.get_collection(name=collection_name) + if collection: + collection.delete(ids=ids) + + def reset(self): + # Resets the database. This will delete all collections and item entries. 
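A note on the insert method above: it leans on Chroma's create_batches helper so oversized payloads are split below the client's maximum batch size before add() is called. A self-contained sketch of that pattern, with toy ids and vectors:

```python
import uuid
import chromadb
from chromadb.utils.batch_utils import create_batches

client = chromadb.EphemeralClient()
collection = client.get_or_create_collection("demo_batches")

items = [
    {"id": str(uuid.uuid4()), "text": f"doc {i}", "vector": [0.1 * i, 0.2 * i], "metadata": {"i": i}}
    for i in range(3)
]

# Each batch unpacks as (ids, embeddings, metadatas, documents), the positional
# order Collection.add() expects, which is why collection.add(*batch) works.
for batch in create_batches(
    api=client,
    ids=[item["id"] for item in items],
    embeddings=[item["vector"] for item in items],
    metadatas=[item["metadata"] for item in items],
    documents=[item["text"] for item in items],
):
    collection.add(*batch)

print(collection.count())  # 3
```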
+ return self.client.reset() diff --git a/backend/open_webui/apps/rag/vector/dbs/milvus.py b/backend/open_webui/apps/rag/vector/dbs/milvus.py new file mode 100644 index 000000000..f205b9521 --- /dev/null +++ b/backend/open_webui/apps/rag/vector/dbs/milvus.py @@ -0,0 +1,205 @@ +from pymilvus import MilvusClient as Client +from pymilvus import FieldSchema, DataType +import json + +from typing import Optional + +from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult +from open_webui.config import ( + MILVUS_URI, +) + + +class MilvusClient: + def __init__(self): + self.collection_prefix = "open_webui" + self.client = Client(uri=MILVUS_URI) + + def _result_to_get_result(self, result) -> GetResult: + print(result) + + ids = [] + documents = [] + metadatas = [] + + for match in result: + _ids = [] + _documents = [] + _metadatas = [] + + for item in match: + _ids.append(item.get("id")) + _documents.append(item.get("data", {}).get("text")) + _metadatas.append(item.get("metadata")) + + ids.append(_ids) + documents.append(_documents) + metadatas.append(_metadatas) + + return GetResult( + **{ + "ids": ids, + "documents": documents, + "metadatas": metadatas, + } + ) + + def _result_to_search_result(self, result) -> SearchResult: + print(result) + + ids = [] + distances = [] + documents = [] + metadatas = [] + + for match in result: + _ids = [] + _distances = [] + _documents = [] + _metadatas = [] + + for item in match: + _ids.append(item.get("id")) + _distances.append(item.get("distance")) + _documents.append(item.get("entity", {}).get("data", {}).get("text")) + _metadatas.append(item.get("entity", {}).get("metadata")) + + ids.append(_ids) + distances.append(_distances) + documents.append(_documents) + metadatas.append(_metadatas) + + return SearchResult( + **{ + "ids": ids, + "distances": distances, + "documents": documents, + "metadatas": metadatas, + } + ) + + def _create_collection(self, collection_name: str, dimension: int): + schema = self.client.create_schema( + auto_id=False, + enable_dynamic_field=True, + ) + schema.add_field( + field_name="id", + datatype=DataType.VARCHAR, + is_primary=True, + max_length=65535, + ) + schema.add_field( + field_name="vector", + datatype=DataType.FLOAT_VECTOR, + dim=dimension, + description="vector", + ) + schema.add_field(field_name="data", datatype=DataType.JSON, description="data") + schema.add_field( + field_name="metadata", datatype=DataType.JSON, description="metadata" + ) + + index_params = self.client.prepare_index_params() + index_params.add_index( + field_name="vector", index_type="HNSW", metric_type="COSINE", params={} + ) + + self.client.create_collection( + collection_name=f"{self.collection_prefix}_{collection_name}", + schema=schema, + index_params=index_params, + ) + + def has_collection(self, collection_name: str) -> bool: + # Check if the collection exists based on the collection name. + return self.client.has_collection( + collection_name=f"{self.collection_prefix}_{collection_name}" + ) + + def delete_collection(self, collection_name: str): + # Delete the collection based on the collection name. + return self.client.drop_collection( + collection_name=f"{self.collection_prefix}_{collection_name}" + ) + + def search( + self, collection_name: str, vectors: list[list[float | int]], limit: int + ) -> Optional[SearchResult]: + # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. 
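To see the result shape this wrapper normalizes, here is a minimal end-to-end sketch using Milvus Lite, where the database lives in a local file as in the MILVUS_URI default; the collection name, ids, and vectors are toy values:

```python
from pymilvus import MilvusClient

client = MilvusClient(uri="./milvus_demo.db")  # Milvus Lite: local-file deployment
client.create_collection(collection_name="demo", dimension=4)

client.insert(
    collection_name="demo",
    data=[
        {"id": 0, "vector": [0.1, 0.2, 0.3, 0.4], "data": {"text": "hello"}, "metadata": {}},
        {"id": 1, "vector": [0.4, 0.3, 0.2, 0.1], "data": {"text": "world"}, "metadata": {}},
    ],
)

# Each hit is a dict carrying "id", "distance", and the requested output fields
# under "entity" -- exactly the layout _result_to_search_result() walks.
hits = client.search(
    collection_name="demo",
    data=[[0.1, 0.2, 0.3, 0.4]],
    limit=2,
    output_fields=["data", "metadata"],
)
print(hits[0][0]["entity"]["data"]["text"])  # "hello"
```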
+ result = self.client.search( + collection_name=f"{self.collection_prefix}_{collection_name}", + data=vectors, + limit=limit, + output_fields=["data", "metadata"], + ) + + return self._result_to_search_result(result) + + def get(self, collection_name: str) -> Optional[GetResult]: + # Get all the items in the collection. + result = self.client.query( + collection_name=f"{self.collection_prefix}_{collection_name}", + filter='id != ""', + ) + return self._result_to_get_result([result]) + + def insert(self, collection_name: str, items: list[VectorItem]): + # Insert the items into the collection, if the collection does not exist, it will be created. + if not self.client.has_collection( + collection_name=f"{self.collection_prefix}_{collection_name}" + ): + self._create_collection( + collection_name=collection_name, dimension=len(items[0]["vector"]) + ) + + return self.client.insert( + collection_name=f"{self.collection_prefix}_{collection_name}", + data=[ + { + "id": item["id"], + "vector": item["vector"], + "data": {"text": item["text"]}, + "metadata": item["metadata"], + } + for item in items + ], + ) + + def upsert(self, collection_name: str, items: list[VectorItem]): + # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. + if not self.client.has_collection( + collection_name=f"{self.collection_prefix}_{collection_name}" + ): + self._create_collection( + collection_name=collection_name, dimension=len(items[0]["vector"]) + ) + + return self.client.upsert( + collection_name=f"{self.collection_prefix}_{collection_name}", + data=[ + { + "id": item["id"], + "vector": item["vector"], + "data": {"text": item["text"]}, + "metadata": item["metadata"], + } + for item in items + ], + ) + + def delete(self, collection_name: str, ids: list[str]): + # Delete the items from the collection based on the ids. + + return self.client.delete( + collection_name=f"{self.collection_prefix}_{collection_name}", + ids=ids, + ) + + def reset(self): + # Resets the database. This will delete all collections and item entries. 
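The get method above has no dedicated fetch-all call to lean on, so it issues query() with an always-true filter. The same idiom against the demo collection from the previous sketch (integer ids there, hence a numeric filter rather than the VARCHAR filter 'id != ""' used above):

```python
from pymilvus import MilvusClient

client = MilvusClient(uri="./milvus_demo.db")

# query() with an always-true filter is Milvus's idiom for "return every row".
rows = client.query(
    collection_name="demo",
    filter="id >= 0",
    output_fields=["data", "metadata"],
)
for row in rows:
    print(row.get("id"), row.get("data"), row.get("metadata"))
```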
+ + collection_names = self.client.list_collections() + for collection_name in collection_names: + if collection_name.startswith(self.collection_prefix): + self.client.drop_collection(collection_name=collection_name) diff --git a/backend/open_webui/apps/rag/vector/main.py b/backend/open_webui/apps/rag/vector/main.py new file mode 100644 index 000000000..f0cf0c038 --- /dev/null +++ b/backend/open_webui/apps/rag/vector/main.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel +from typing import Optional, List, Any + + +class VectorItem(BaseModel): + id: str + text: str + vector: List[float | int] + metadata: Any + + +class GetResult(BaseModel): + ids: Optional[List[List[str]]] + documents: Optional[List[List[str]]] + metadatas: Optional[List[List[Any]]] + + +class SearchResult(GetResult): + distances: Optional[List[List[float | int]]] diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/apps/socket/main.py index 5985bc524..e41ef8412 100644 --- a/backend/open_webui/apps/socket/main.py +++ b/backend/open_webui/apps/socket/main.py @@ -2,9 +2,16 @@ import asyncio import socketio from open_webui.apps.webui.models.users import Users +from open_webui.env import ENABLE_WEBSOCKET_SUPPORT from open_webui.utils.utils import decode_token -sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi") +sio = socketio.AsyncServer( + cors_allowed_origins=[], + async_mode="asgi", + transports=(["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]), + allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, + always_connect=True, +) app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io") # Dictionary to maintain the user pool @@ -32,7 +39,7 @@ async def connect(sid, environ, auth): else: USER_POOL[user.id] = [sid] - print(f"user {user.name}({user.id}) connected with session ID {sid}") + # print(f"user {user.name}({user.id}) connected with session ID {sid}") await sio.emit("user-count", {"count": len(set(USER_POOL))}) await sio.emit("usage", {"models": get_models_in_use()}) @@ -40,7 +47,7 @@ async def connect(sid, environ, auth): @sio.on("user-join") async def user_join(sid, data): - print("user-join", sid, data) + # print("user-join", sid, data) auth = data["auth"] if "auth" in data else None if not auth or "token" not in auth: @@ -60,7 +67,7 @@ async def user_join(sid, data): else: USER_POOL[user.id] = [sid] - print(f"user {user.name}({user.id}) connected with session ID {sid}") + # print(f"user {user.name}({user.id}) connected with session ID {sid}") await sio.emit("user-count", {"count": len(set(USER_POOL))}) @@ -109,7 +116,7 @@ async def remove_after_timeout(sid, model_id): try: await asyncio.sleep(TIMEOUT_DURATION) if model_id in USAGE_POOL: - print(USAGE_POOL[model_id]["sids"]) + # print(USAGE_POOL[model_id]["sids"]) USAGE_POOL[model_id]["sids"].remove(sid) USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"])) @@ -136,7 +143,8 @@ async def disconnect(sid): await sio.emit("user-count", {"count": len(USER_POOL)}) else: - print(f"Unknown session ID {sid} disconnected") + pass + # print(f"Unknown session ID {sid} disconnected") def get_event_emitter(request_info): diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/apps/webui/routers/auths.py index 2366841e1..bfa460836 100644 --- a/backend/open_webui/apps/webui/routers/auths.py +++ b/backend/open_webui/apps/webui/routers/auths.py @@ -190,8 +190,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm): async def signup(request: Request, response: Response, 
form_data: SignupForm): if ( not request.app.state.config.ENABLE_SIGNUP - and request.app.state.config.ENABLE_LOGIN_FORM - and WEBUI_AUTH + or not request.app.state.config.ENABLE_LOGIN_FORM + or not WEBUI_AUTH ): raise HTTPException( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED diff --git a/backend/open_webui/apps/webui/routers/memories.py b/backend/open_webui/apps/webui/routers/memories.py index 914b69e7e..d659833bc 100644 --- a/backend/open_webui/apps/webui/routers/memories.py +++ b/backend/open_webui/apps/webui/routers/memories.py @@ -1,12 +1,13 @@ +from fastapi import APIRouter, Depends, HTTPException, Request +from pydantic import BaseModel import logging from typing import Optional from open_webui.apps.webui.models.memories import Memories, MemoryModel -from open_webui.config import CHROMA_CLIENT -from open_webui.env import SRC_LOG_LEVELS -from fastapi import APIRouter, Depends, HTTPException, Request -from pydantic import BaseModel +from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.utils import get_verified_user +from open_webui.env import SRC_LOG_LEVELS + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -49,14 +50,17 @@ async def add_memory( user=Depends(get_verified_user), ): memory = Memories.insert_new_memory(user.id, form_data.content) - memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) - collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") - collection.upsert( - documents=[memory.content], - ids=[memory.id], - embeddings=[memory_embedding], - metadatas=[{"created_at": memory.created_at}], + VECTOR_DB_CLIENT.upsert( + collection_name=f"user-memory-{user.id}", + items=[ + { + "id": memory.id, + "text": memory.content, + "vector": request.app.state.EMBEDDING_FUNCTION(memory.content), + "metadata": {"created_at": memory.created_at}, + } + ], ) return memory @@ -76,12 +80,10 @@ class QueryMemoryForm(BaseModel): async def query_memory( request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user) ): - query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content) - collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") - - results = collection.query( - query_embeddings=[query_embedding], - n_results=form_data.k, # how many results to return + results = VECTOR_DB_CLIENT.search( + collection_name=f"user-memory-{user.id}", + vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)], + limit=form_data.k, ) return results @@ -94,17 +96,25 @@ async def query_memory( async def reset_memory_from_vector_db( request: Request, user=Depends(get_verified_user) ): - CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}") - collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") + VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}") memories = Memories.get_memories_by_user_id(user.id) - for memory in memories: - memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) - collection.upsert( - documents=[memory.content], - ids=[memory.id], - embeddings=[memory_embedding], - ) + VECTOR_DB_CLIENT.upsert( + collection_name=f"user-memory-{user.id}", + items=[ + { + "id": memory.id, + "text": memory.content, + "vector": request.app.state.EMBEDDING_FUNCTION(memory.content), + "metadata": { + "created_at": memory.created_at, + "updated_at": memory.updated_at, + }, + } + for memory in memories + ], + ) + return True @@ -119,7 +129,7 @@ async def 
delete_memory_by_user_id(user=Depends(get_verified_user)): if result: try: - CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}") + VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}") except Exception as e: log.error(e) return True @@ -144,16 +154,18 @@ async def update_memory_by_id( raise HTTPException(status_code=404, detail="Memory not found") if form_data.content is not None: - memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content) - collection = CHROMA_CLIENT.get_or_create_collection( - name=f"user-memory-{user.id}" - ) - collection.upsert( - documents=[form_data.content], - ids=[memory.id], - embeddings=[memory_embedding], - metadatas=[ - {"created_at": memory.created_at, "updated_at": memory.updated_at} + VECTOR_DB_CLIENT.upsert( + collection_name=f"user-memory-{user.id}", + items=[ + { + "id": memory.id, + "text": memory.content, + "vector": request.app.state.EMBEDDING_FUNCTION(memory.content), + "metadata": { + "created_at": memory.created_at, + "updated_at": memory.updated_at, + }, + } ], ) @@ -170,10 +182,9 @@ async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)): result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id) if result: - collection = CHROMA_CLIENT.get_or_create_collection( - name=f"user-memory-{user.id}" + VECTOR_DB_CLIENT.delete( + collection_name=f"user-memory-{user.id}", ids=[memory_id] ) - collection.delete(ids=[memory_id]) return True return False diff --git a/backend/open_webui/apps/webui/routers/models.py b/backend/open_webui/apps/webui/routers/models.py index a99c65d76..a5cb2395e 100644 --- a/backend/open_webui/apps/webui/routers/models.py +++ b/backend/open_webui/apps/webui/routers/models.py @@ -18,8 +18,18 @@ router = APIRouter() @router.get("/", response_model=list[ModelResponse]) -async def get_models(user=Depends(get_verified_user)): - return Models.get_all_models() +async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): + if id: + model = Models.get_model_by_id(id) + if model: + return [model] + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + else: + return Models.get_all_models() ############################ @@ -50,24 +60,6 @@ async def add_new_model( ) -############################ -# GetModelById -############################ - - -@router.get("/", response_model=Optional[ModelModel]) -async def get_model_by_id(id: str, user=Depends(get_verified_user)): - model = Models.get_model_by_id(id) - - if model: - return model - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - - ############################ # UpdateModelById ############################ diff --git a/backend/open_webui/apps/webui/utils.py b/backend/open_webui/apps/webui/utils.py index 2d537af51..8bf48e400 100644 --- a/backend/open_webui/apps/webui/utils.py +++ b/backend/open_webui/apps/webui/utils.py @@ -4,7 +4,7 @@ import subprocess import sys from importlib import util import types - +import tempfile from open_webui.apps.webui.models.functions import Functions from open_webui.apps.webui.models.tools import Tools @@ -84,7 +84,15 @@ def load_toolkit_module_by_id(toolkit_id, content=None): module = types.ModuleType(module_name) sys.modules[module_name] = module + # Create a temporary file and use it to define `__file__` so + # that it works as expected from the module's perspective. 
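For context on why the temporary-file dance in these loaders exists: a module materialized via exec() has no __file__, so plugin code that inspects its own location would raise NameError. A standalone sketch of the same pattern:

```python
import os
import sys
import tempfile
import types

content = 'import os\nprint("running from", os.path.dirname(__file__))\n'

module = types.ModuleType("demo_plugin")
sys.modules["demo_plugin"] = module

# Write the source to a real file and point __file__ at it before exec(),
# mirroring load_toolkit_module_by_id; without this, __file__ is simply undefined.
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".py")
try:
    with open(temp_file.name, "w", encoding="utf-8") as f:
        f.write(content)
    module.__dict__["__file__"] = temp_file.name
    exec(content, module.__dict__)
finally:
    os.unlink(temp_file.name)
```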
+ temp_file = tempfile.NamedTemporaryFile(delete=False) + try: + with open(temp_file.name, "w", encoding="utf-8") as f: + f.write(content) + module.__dict__["__file__"] = temp_file.name + # Executing the modified content in the created module's namespace exec(content, module.__dict__) frontmatter = extract_frontmatter(content) @@ -96,9 +104,11 @@ def load_toolkit_module_by_id(toolkit_id, content=None): else: raise Exception("No Tools class found in the module") except Exception as e: - print(f"Error loading module: {toolkit_id}") + print(f"Error loading module: {toolkit_id}: {e}") del sys.modules[module_name] # Clean up raise e + finally: + os.unlink(temp_file.name) def load_function_module_by_id(function_id, content=None): @@ -118,7 +128,14 @@ def load_function_module_by_id(function_id, content=None): module = types.ModuleType(module_name) sys.modules[module_name] = module + # Create a temporary file and use it to define `__file__` so + # that it works as expected from the module's perspective. + temp_file = tempfile.NamedTemporaryFile(delete=False) try: + with open(temp_file.name, "w", encoding="utf-8") as f: + f.write(content) + module.__dict__["__file__"] = temp_file.name + # Execute the modified content in the created module's namespace exec(content, module.__dict__) frontmatter = extract_frontmatter(content) @@ -134,11 +151,13 @@ def load_function_module_by_id(function_id, content=None): else: raise Exception("No Function class found in the module") except Exception as e: - print(f"Error loading module: {function_id}") + print(f"Error loading module: {function_id}: {e}") del sys.modules[module_name] # Cleanup by removing the module in case of error Functions.update_function_by_id(function_id, {"is_active": False}) raise e + finally: + os.unlink(temp_file.name) def install_frontmatter_requirements(requirements): diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 5ccb40d47..86d8a47a3 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -11,7 +11,6 @@ import chromadb import requests import yaml from open_webui.apps.webui.internal.db import Base, get_db -from chromadb import Settings from open_webui.env import ( OPEN_WEBUI_DIR, DATA_DIR, @@ -540,40 +539,6 @@ Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions") Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True) - -#################################### -# LITELLM_CONFIG -#################################### - - -def create_config_file(file_path): - directory = os.path.dirname(file_path) - - # Check if directory exists, if not, create it - if not os.path.exists(directory): - os.makedirs(directory) - - # Data to write into the YAML file - config_data = { - "general_settings": {}, - "litellm_settings": {}, - "model_list": [], - "router_settings": {}, - } - - # Write data to YAML file - with open(file_path, "w") as file: - yaml.dump(config_data, file) - - -LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml" - -# if not os.path.exists(LITELLM_CONFIG_PATH): -# log.info("Config file doesn't exist. 
Creating...") -# create_config_file(LITELLM_CONFIG_PATH) -# log.info("Config file created successfully.") - - #################################### # OLLAMA_BASE_URL #################################### @@ -923,25 +888,12 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( #################################### -# RAG document content extraction +# Vector Database #################################### -CONTENT_EXTRACTION_ENGINE = PersistentConfig( - "CONTENT_EXTRACTION_ENGINE", - "rag.CONTENT_EXTRACTION_ENGINE", - os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(), -) - -TIKA_SERVER_URL = PersistentConfig( - "TIKA_SERVER_URL", - "rag.tika_server_url", - os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment -) - -#################################### -# RAG -#################################### +VECTOR_DB = os.environ.get("VECTOR_DB", "chroma") +# Chroma CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT) CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE) @@ -958,8 +910,29 @@ else: CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) +# Milvus + +MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db") + +#################################### +# RAG +#################################### + +# RAG Content Extraction +CONTENT_EXTRACTION_ENGINE = PersistentConfig( + "CONTENT_EXTRACTION_ENGINE", + "rag.CONTENT_EXTRACTION_ENGINE", + os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(), +) + +TIKA_SERVER_URL = PersistentConfig( + "TIKA_SERVER_URL", + "rag.tika_server_url", + os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment +) + RAG_TOP_K = PersistentConfig( - "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5")) + "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3")) ) RAG_RELEVANCE_THRESHOLD = PersistentConfig( "RAG_RELEVANCE_THRESHOLD", @@ -1048,36 +1021,8 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) - -if CHROMA_HTTP_HOST != "": - CHROMA_CLIENT = chromadb.HttpClient( - host=CHROMA_HTTP_HOST, - port=CHROMA_HTTP_PORT, - headers=CHROMA_HTTP_HEADERS, - ssl=CHROMA_HTTP_SSL, - tenant=CHROMA_TENANT, - database=CHROMA_DATABASE, - settings=Settings(allow_reset=True, anonymized_telemetry=False), - ) -else: - CHROMA_CLIENT = chromadb.PersistentClient( - path=CHROMA_DATA_PATH, - settings=Settings(allow_reset=True, anonymized_telemetry=False), - tenant=CHROMA_TENANT, - database=CHROMA_DATABASE, - ) - - -# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance -USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") - -if USE_CUDA.lower() == "true": - DEVICE_TYPE = "cuda" -else: - DEVICE_TYPE = "cpu" - CHUNK_SIZE = PersistentConfig( - "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500")) + "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1000")) ) CHUNK_OVERLAP = PersistentConfig( "CHUNK_OVERLAP", @@ -1085,19 +1030,25 @@ CHUNK_OVERLAP = PersistentConfig( int(os.environ.get("CHUNK_OVERLAP", "100")), ) -DEFAULT_RAG_TEMPLATE = """Use the following context as your 
learned knowledge, inside <context></context> XML tags. +DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules. + <context> - [context] +[context] </context> -When answer to user: -- If you don't know, just say that you don't know. -- If you don't know when you are not sure, ask for clarification. -Avoid mentioning that you obtained the information from the context. -And answer according to the language of the user's question. +<rules> +- If you don't know, just say so. +- If you are not sure, ask for clarification. +- Answer in the same language as the user query. +- If the context appears unreadable or of poor quality, tell the user then answer as best as you can. +- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge. +- Answer directly and without using xml tags. +</rules> -Given the context information, answer the query. -Query: [query]""" + +<user_query> +[query] +</user_query> +""" RAG_TEMPLATE = PersistentConfig( "RAG_TEMPLATE", @@ -1267,6 +1218,37 @@ AUTOMATIC1111_API_AUTH = PersistentConfig( os.getenv("AUTOMATIC1111_API_AUTH", ""), ) +AUTOMATIC1111_CFG_SCALE = PersistentConfig( + "AUTOMATIC1111_CFG_SCALE", + "image_generation.automatic1111.cfg_scale", + ( + float(os.environ.get("AUTOMATIC1111_CFG_SCALE")) + if os.environ.get("AUTOMATIC1111_CFG_SCALE") + else None + ), +) + + +AUTOMATIC1111_SAMPLER = PersistentConfig( + "AUTOMATIC1111_SAMPLER", + "image_generation.automatic1111.sampler", + ( + os.environ.get("AUTOMATIC1111_SAMPLER") + if os.environ.get("AUTOMATIC1111_SAMPLER") + else None + ), +) + +AUTOMATIC1111_SCHEDULER = PersistentConfig( + "AUTOMATIC1111_SCHEDULER", + "image_generation.automatic1111.scheduler", + ( + os.environ.get("AUTOMATIC1111_SCHEDULER") + if os.environ.get("AUTOMATIC1111_SCHEDULER") + else None + ), +) + COMFYUI_BASE_URL = PersistentConfig( "COMFYUI_BASE_URL", "image_generation.comfyui.base_url",
" + f"Resetting USE_CUDA_DOCKER to false: {e}" + ) + os.environ["USE_CUDA_DOCKER"] = "false" + USE_CUDA = "false" + DEVICE_TYPE = "cpu" +else: + DEVICE_TYPE = "cpu" + #################################### # LOGGING @@ -47,6 +69,9 @@ else: log = logging.getLogger(__name__) log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}") +if "cuda_error" in locals(): + log.exception(cuda_error) + log_sources = [ "AUDIO", "COMFYUI", @@ -273,3 +298,7 @@ WEBUI_SESSION_COOKIE_SECURE = os.environ.get( if WEBUI_AUTH and WEBUI_SECRET_KEY == "": raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) + +ENABLE_WEBSOCKET_SUPPORT = ( + os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true" +) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 8914cb491..319d95165 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -109,6 +109,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware from starlette.responses import RedirectResponse, Response, StreamingResponse +from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.misc import ( add_or_update_system_message, @@ -586,8 +587,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if len(contexts) > 0: context_string = "/n".join(contexts).strip() prompt = get_last_user_message(body["messages"]) + if prompt is None: raise Exception("No user message found") + if ( + rag_app.state.config.RELEVANCE_THRESHOLD == 0 + and context_string.strip() == "" + ): + log.debug( + f"With a 0 relevancy threshold for RAG, the context cannot be empty" + ) + # Workaround for Ollama 2.0+ system prompt issue # TODO: replace with add_or_update_system_message if model["owned_by"] == "ollama": @@ -780,6 +790,8 @@ app.add_middleware( allow_headers=["*"], ) +app.add_middleware(SecurityHeadersMiddleware) + @app.middleware("http") async def commit_session_after_request(request: Request, call_next): @@ -812,6 +824,24 @@ async def update_embedding_function(request: Request, call_next): return response +@app.middleware("http") +async def inspect_websocket(request: Request, call_next): + if ( + "/ws/socket.io" in request.url.path + and request.query_params.get("transport") == "websocket" + ): + upgrade = (request.headers.get("Upgrade") or "").lower() + connection = (request.headers.get("Connection") or "").lower().split(",") + # Check that there's the correct headers for an upgrade, else reject the connection + # This is to work around this upstream issue: https://github.com/miguelgrinberg/python-engineio/issues/367 + if upgrade != "websocket" or "upgrade" not in connection: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": "Invalid WebSocket upgrade request"}, + ) + return await call_next(request) + + app.mount("/ws", socket_app) app.mount("/ollama", ollama_app) @@ -1368,9 +1398,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): # Check if the user has a custom task model # If the user has a custom task model, use that model - model_id = get_task_model_id(model_id) + task_model_id = get_task_model_id(model_id) - print(model_id) + print(task_model_id) if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE @@ -1397,10 +1427,16 @@ Prompt: {{prompt:middletruncate:8000}}""" ) payload = { - "model": model_id, + "model": task_model_id, "messages": [{"role": "user", "content": content}], "stream": False, - "max_tokens": 
50, + **( + {"max_tokens": 50} + if app.state.MODELS[task_model_id]["owned_by"] == "ollama" + else { + "max_completion_tokens": 50, + } + ), "chat_id": form_data.get("chat_id", None), "metadata": {"task": str(TASKS.TITLE_GENERATION)}, } @@ -1445,9 +1481,8 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) # Check if the user has a custom task model # If the user has a custom task model, use that model - model_id = get_task_model_id(model_id) - - print(model_id) + task_model_id = get_task_model_id(model_id) + print(task_model_id) if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "": template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE @@ -1469,10 +1504,16 @@ Search Query:""" print("content", content) payload = { - "model": model_id, + "model": task_model_id, "messages": [{"role": "user", "content": content}], "stream": False, - "max_tokens": 30, + **( + {"max_tokens": 30} + if app.state.MODELS[task_model_id]["owned_by"] == "ollama" + else { + "max_completion_tokens": 30, + } + ), "metadata": {"task": str(TASKS.QUERY_GENERATION)}, } @@ -1511,9 +1552,8 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): # Check if the user has a custom task model # If the user has a custom task model, use that model - model_id = get_task_model_id(model_id) - - print(model_id) + task_model_id = get_task_model_id(model_id) + print(task_model_id) template = ''' Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). @@ -1531,10 +1571,16 @@ Message: """{{prompt}}""" ) payload = { - "model": model_id, + "model": task_model_id, "messages": [{"role": "user", "content": content}], "stream": False, - "max_tokens": 4, + **( + {"max_tokens": 4} + if app.state.MODELS[task_model_id]["owned_by"] == "ollama" + else { + "max_completion_tokens": 4, + } + ), "chat_id": form_data.get("chat_id", None), "metadata": {"task": str(TASKS.EMOJI_GENERATION)}, } diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index 227cca45f..b2654cd25 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -44,9 +44,9 @@ def apply_model_params_to_body( def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: mappings = { "temperature": float, - "top_p": int, + "top_p": float, "max_tokens": int, - "frequency_penalty": int, + "frequency_penalty": float, "seed": lambda x: x, "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], } diff --git a/backend/open_webui/utils/security_headers.py b/backend/open_webui/utils/security_headers.py new file mode 100644 index 000000000..69a464814 --- /dev/null +++ b/backend/open_webui/utils/security_headers.py @@ -0,0 +1,115 @@ +import re +import os + +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware +from typing import Dict + + +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + response = await call_next(request) + response.headers.update(set_security_headers()) + return response + + +def set_security_headers() -> Dict[str, str]: + """ + Sets security headers based on environment variables. + + This function reads specific environment variables and uses their values + to set corresponding security headers. 
The headers that can be set are: + - cache-control + - strict-transport-security + - referrer-policy + - x-content-type-options + - x-download-options + - x-frame-options + - x-permitted-cross-domain-policies + + Each environment variable is associated with a specific setter function + that constructs the header. If the environment variable is set, the + corresponding header is added to the options dictionary. + + Returns: + dict: A dictionary containing the security headers and their values. + """ + options = {} + header_setters = { + "CACHE_CONTROL": set_cache_control, + "HSTS": set_hsts, + "REFERRER_POLICY": set_referrer, + "XCONTENT_TYPE": set_xcontent_type, + "XDOWNLOAD_OPTIONS": set_xdownload_options, + "XFRAME_OPTIONS": set_xframe, + "XPERMITTED_CROSS_DOMAIN_POLICIES": set_xpermitted_cross_domain_policies, + } + + for env_var, setter in header_setters.items(): + value = os.environ.get(env_var, None) + if value: + header = setter(value) + if header: + options.update(header) + + return options + + +# Set HTTP Strict Transport Security (HSTS) response header +def set_hsts(value: str): + pattern = r"^max-age=(\d+)(;includeSubDomains)?(;preload)?$" + match = re.match(pattern, value, re.IGNORECASE) + if not match: + value = "max-age=31536000;includeSubDomains" + return {"Strict-Transport-Security": value} + + +# Set X-Frame-Options response header +def set_xframe(value: str): + pattern = r"^(DENY|SAMEORIGIN)$" + match = re.match(pattern, value, re.IGNORECASE) + if not match: + value = "DENY" + return {"X-Frame-Options": value} + + +# Set Referrer-Policy response header +def set_referrer(value: str): + pattern = r"^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$" + match = re.match(pattern, value, re.IGNORECASE) + if not match: + value = "no-referrer" + return {"Referrer-Policy": value} + + +# Set Cache-Control response header +def set_cache_control(value: str): + pattern = r"^(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable)(,\s*(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable))*$" + match = re.match(pattern, value, re.IGNORECASE) + if not match: + value = "no-store, max-age=0" + + return {"Cache-Control": value} + + +# Set X-Download-Options response header +def set_xdownload_options(value: str): + if value != "noopen": + value = "noopen" + return {"X-Download-Options": value} + + +# Set X-Content-Type-Options response header +def set_xcontent_type(value: str): + if value != "nosniff": + value = "nosniff" + return {"X-Content-Type-Options": value} + + +# Set X-Permitted-Cross-Domain-Policies response header +def set_xpermitted_cross_domain_policies(value: str): + pattern = r"^(none|master-only|by-content-type|by-ftp-filename)$" + match = re.match(pattern, value, re.IGNORECASE) + if not match: + value = "none" + return {"X-Permitted-Cross-Domain-Policies": value} diff --git a/backend/requirements.txt b/backend/requirements.txt index 93720cc84..2554bb5f8 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -40,7 +40,12 @@ langchain-chroma==0.1.2 fake-useragent==1.5.1 chromadb==0.5.5 +pymilvus==2.4.6 + sentence-transformers==3.0.1 +colbert-ai==0.2.21 +einops==0.8.0 + pypdf==4.3.1 docx2txt==0.8 python-pptx==1.0.0 diff --git a/kubernetes/manifest/base/kustomization.yaml b/kubernetes/manifest/base/kustomization.yaml new file mode 100644 index
000000000..61500f87c --- /dev/null +++ b/kubernetes/manifest/base/kustomization.yaml @@ -0,0 +1,8 @@ +resources: + - open-webui.yaml + - ollama-service.yaml + - ollama-statefulset.yaml + - webui-deployment.yaml + - webui-service.yaml + - webui-ingress.yaml + - webui-pvc.yaml diff --git a/kubernetes/manifest/gpu/kustomization.yaml b/kubernetes/manifest/gpu/kustomization.yaml new file mode 100644 index 000000000..c0d39fbfa --- /dev/null +++ b/kubernetes/manifest/gpu/kustomization.yaml @@ -0,0 +1,8 @@ +apiVersion: kustomize.config.k8s.io/v1beta1 +kind: Kustomization + +resources: + - ../base + +patches: +- path: ollama-statefulset-gpu.yaml diff --git a/kubernetes/manifest/patches/ollama-statefulset-gpu.yaml b/kubernetes/manifest/gpu/ollama-statefulset-gpu.yaml similarity index 100% rename from kubernetes/manifest/patches/ollama-statefulset-gpu.yaml rename to kubernetes/manifest/gpu/ollama-statefulset-gpu.yaml diff --git a/kubernetes/manifest/kustomization.yaml b/kubernetes/manifest/kustomization.yaml deleted file mode 100644 index 907bff3e1..000000000 --- a/kubernetes/manifest/kustomization.yaml +++ /dev/null @@ -1,13 +0,0 @@ -resources: -- base/open-webui.yaml -- base/ollama-service.yaml -- base/ollama-statefulset.yaml -- base/webui-deployment.yaml -- base/webui-service.yaml -- base/webui-ingress.yaml -- base/webui-pvc.yaml - -apiVersion: kustomize.config.k8s.io/v1beta1 -kind: Kustomization -patches: -- path: patches/ollama-statefulset-gpu.yaml diff --git a/package-lock.json b/package-lock.json index 9ec44ffa7..a0048a211 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,18 +1,19 @@ { "name": "open-webui", - "version": "0.3.21", + "version": "0.3.22", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.3.21", + "version": "0.3.22", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", "@codemirror/theme-one-dark": "^6.1.2", "@pyscript/core": "^0.4.32", "@sveltejs/adapter-node": "^2.0.0", + "@xyflow/svelte": "^0.1.19", "async": "^3.2.5", "bits-ui": "^0.19.7", "codemirror": "^6.0.1", @@ -45,7 +46,6 @@ "@sveltejs/kit": "^2.5.20", "@sveltejs/vite-plugin-svelte": "^3.1.1", "@tailwindcss/typography": "^0.5.13", - "@types/bun": "latest", "@typescript-eslint/eslint-plugin": "^6.17.0", "@typescript-eslint/parser": "^6.17.0", "autoprefixer": "^10.4.16", @@ -1368,6 +1368,14 @@ "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz", "integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==" }, + "node_modules/@svelte-put/shortcut": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/@svelte-put/shortcut/-/shortcut-3.1.1.tgz", + "integrity": "sha512-2L5EYTZXiaKvbEelVkg5znxqvfZGZai3m97+cAiUBhLZwXnGtviTDpHxOoZBsqz41szlfRMcamW/8o0+fbW3ZQ==", + "peerDependencies": { + "svelte": "^3.55.0 || ^4.0.0 || ^5.0.0" + } + }, "node_modules/@sveltejs/adapter-auto": { "version": "3.2.2", "resolved": "https://registry.npmjs.org/@sveltejs/adapter-auto/-/adapter-auto-3.2.2.tgz", @@ -1494,20 +1502,32 @@ "tailwindcss": ">=3.0.0 || insiders" } }, - "node_modules/@types/bun": { - "version": "1.0.10", - "resolved": "https://registry.npmjs.org/@types/bun/-/bun-1.0.10.tgz", - "integrity": "sha512-Jaz6YYAdm1u3NVlgSyEK+qGmrlLQ20sbWeEoXD64b9w6z/YKYNWlfaphu+xF2Kiy5Tpykm5Q9jIquLegwXx4ng==", - "dev": true, - "dependencies": { - "bun-types": "1.0.33" - } - }, "node_modules/@types/cookie": { "version": "0.6.0", "resolved": 
"https://registry.npmjs.org/@types/cookie/-/cookie-0.6.0.tgz", "integrity": "sha512-4Kh9a6B2bQciAhf7FSuMRRkUWecJgJu9nPnx3yzpsfXX/c50REIqpHY4C82bXP90qrLtXtkDxTZosYO3UpOwlA==" }, + "node_modules/@types/d3-color": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/@types/d3-color/-/d3-color-3.1.3.tgz", + "integrity": "sha512-iO90scth9WAbmgv7ogoq57O9YpKmFBbmoEoCHDB2xMBY0+/KVrqAaCDyCE16dUspeOvIxFFRI+0sEtqDqy2b4A==" + }, + "node_modules/@types/d3-drag": { + "version": "3.0.7", + "resolved": "https://registry.npmjs.org/@types/d3-drag/-/d3-drag-3.0.7.tgz", + "integrity": "sha512-HE3jVKlzU9AaMazNufooRJ5ZpWmLIoc90A37WU2JMmeq28w1FQqCZswHZ3xR+SuxYftzHq6WU6KJHvqxKzTxxQ==", + "dependencies": { + "@types/d3-selection": "*" + } + }, + "node_modules/@types/d3-interpolate": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@types/d3-interpolate/-/d3-interpolate-3.0.4.tgz", + "integrity": "sha512-mgLPETlrpVV1YRJIglr4Ez47g7Yxjl1lj7YKsiMCb27VJH9W8NVM6Bb9d8kkpG/uAQS5AmbA48q2IAolKKo1MA==", + "dependencies": { + "@types/d3-color": "*" + } + }, "node_modules/@types/d3-scale": { "version": "4.0.8", "resolved": "https://registry.npmjs.org/@types/d3-scale/-/d3-scale-4.0.8.tgz", @@ -1521,11 +1541,33 @@ "resolved": "https://registry.npmjs.org/@types/d3-scale-chromatic/-/d3-scale-chromatic-3.0.3.tgz", "integrity": "sha512-laXM4+1o5ImZv3RpFAsTRn3TEkzqkytiOY0Dz0sq5cnd1dtNlk6sHLon4OvqaiJb28T0S/TdsBI3Sjsy+keJrw==" }, + "node_modules/@types/d3-selection": { + "version": "3.0.10", + "resolved": "https://registry.npmjs.org/@types/d3-selection/-/d3-selection-3.0.10.tgz", + "integrity": "sha512-cuHoUgS/V3hLdjJOLTT691+G2QoqAjCVLmr4kJXR4ha56w1Zdu8UUQ5TxLRqudgNjwXeQxKMq4j+lyf9sWuslg==" + }, "node_modules/@types/d3-time": { "version": "3.0.3", "resolved": "https://registry.npmjs.org/@types/d3-time/-/d3-time-3.0.3.tgz", "integrity": "sha512-2p6olUZ4w3s+07q3Tm2dbiMZy5pCDfYwtLXXHUnVzXgQlZ/OyPtUz6OL382BkOuGlLXqfT+wqv8Fw2v8/0geBw==" }, + "node_modules/@types/d3-transition": { + "version": "3.0.8", + "resolved": "https://registry.npmjs.org/@types/d3-transition/-/d3-transition-3.0.8.tgz", + "integrity": "sha512-ew63aJfQ/ms7QQ4X7pk5NxQ9fZH/z+i24ZfJ6tJSfqxJMrYLiK01EAs2/Rtw/JreGUsS3pLPNV644qXFGnoZNQ==", + "dependencies": { + "@types/d3-selection": "*" + } + }, + "node_modules/@types/d3-zoom": { + "version": "3.0.8", + "resolved": "https://registry.npmjs.org/@types/d3-zoom/-/d3-zoom-3.0.8.tgz", + "integrity": "sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==", + "dependencies": { + "@types/d3-interpolate": "*", + "@types/d3-selection": "*" + } + }, "node_modules/@types/debug": { "version": "4.1.12", "resolved": "https://registry.npmjs.org/@types/debug/-/debug-4.1.12.tgz", @@ -1568,7 +1610,7 @@ "version": "20.11.30", "resolved": "https://registry.npmjs.org/@types/node/-/node-20.11.30.tgz", "integrity": "sha512-dHM6ZxwlmuZaRmUPfv1p+KrdD1Dci04FbdEm/9wEMouFqxYoFl5aMkt0VMAUtYRQDyYvD41WJLukhq/ha3YuTw==", - "devOptional": true, + "optional": true, "dependencies": { "undici-types": "~5.26.4" } @@ -1613,15 +1655,6 @@ "resolved": "https://registry.npmjs.org/@types/unist/-/unist-2.0.10.tgz", "integrity": "sha512-IfYcSBWE3hLpBg8+X2SEa8LVkJdJEkT2Ese2aaLs3ptGdVtABxndrMaxuFlQ1qdFf9Q5rDvDpxI3WwgvKFAsQA==" }, - "node_modules/@types/ws": { - "version": "8.5.10", - "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.10.tgz", - "integrity": "sha512-vmQSUcfalpIq0R9q7uTo2lXs6eGIpt9wtnLdMv9LVpIjCA/+ufZRozlVoVelIYixx1ugCBKDhn89vnsEGOCx9A==", - "dev": true, - "dependencies": { - 
"@types/node": "*" - } - }, "node_modules/@types/yauzl": { "version": "2.10.3", "resolved": "https://registry.npmjs.org/@types/yauzl/-/yauzl-2.10.3.tgz", @@ -1942,6 +1975,33 @@ "resolved": "https://registry.npmjs.org/@webreflection/fetch/-/fetch-0.1.5.tgz", "integrity": "sha512-zCcqCJoNLvdeF41asAK71XPlwSPieeRDsE09albBunJEksuYPYNillKNQjf8p5BqSoTKTuKrW3lUm3MNodUC4g==" }, + "node_modules/@xyflow/svelte": { + "version": "0.1.19", + "resolved": "https://registry.npmjs.org/@xyflow/svelte/-/svelte-0.1.19.tgz", + "integrity": "sha512-yW5w5aI+Yqkob4kLQpVDo/ZmX+E9Pw7459kqwLfv4YG4N1NYXrsDRh9cyph/rapbuDnPi6zqK5E8LKrgaCQC0w==", + "dependencies": { + "@svelte-put/shortcut": "^3.1.0", + "@xyflow/system": "0.0.42", + "classcat": "^5.0.4" + }, + "peerDependencies": { + "svelte": "^3.0.0 || ^4.0.0" + } + }, + "node_modules/@xyflow/system": { + "version": "0.0.42", + "resolved": "https://registry.npmjs.org/@xyflow/system/-/system-0.0.42.tgz", + "integrity": "sha512-kWYj+Y0GOct0jKYTdyRMNOLPxGNbb2TYvPg2gTmJnZ31DOOMkL5uRBLX825DR2gOACDu+i5FHLxPJUPf/eGOJw==", + "dependencies": { + "@types/d3-drag": "^3.0.7", + "@types/d3-selection": "^3.0.10", + "@types/d3-transition": "^3.0.8", + "@types/d3-zoom": "^3.0.8", + "d3-drag": "^3.0.0", + "d3-selection": "^3.0.0", + "d3-zoom": "^3.0.0" + } + }, "node_modules/acorn": { "version": "8.11.3", "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.11.3.tgz", @@ -2533,16 +2593,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/bun-types": { - "version": "1.0.33", - "resolved": "https://registry.npmjs.org/bun-types/-/bun-types-1.0.33.tgz", - "integrity": "sha512-L5tBIf9g6rBBkvshqysi5NoLQ9NnhSPU1pfJ9FzqoSfofYdyac3WLUnOIuQ+M5za/sooVUOP2ko+E6Tco0OLIA==", - "dev": true, - "dependencies": { - "@types/node": "~20.11.3", - "@types/ws": "~8.5.10" - } - }, "node_modules/cac": { "version": "6.7.14", "resolved": "https://registry.npmjs.org/cac/-/cac-6.7.14.tgz", @@ -2777,6 +2827,11 @@ "node": ">=8" } }, + "node_modules/classcat": { + "version": "5.0.5", + "resolved": "https://registry.npmjs.org/classcat/-/classcat-5.0.5.tgz", + "integrity": "sha512-JhZUT7JFcQy/EzW605k/ktHtncoo9vnyW/2GspNYwFlN1C/WmjuV/xtS04e9SOkL2sTdw0VAZ2UGCcQ9lR6p6w==" + }, "node_modules/clean-stack": { "version": "2.2.0", "resolved": "https://registry.npmjs.org/clean-stack/-/clean-stack-2.2.0.tgz", @@ -7094,9 +7149,9 @@ } }, "node_modules/picocolors": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.1.tgz", - "integrity": "sha512-anP1Z8qwhkbmu7MFP5iTt+wQKXgwzf7zTyGlcdzabySa9vd0Xt392U0rVmz9poOaBj0uHJKyyo9/upk0HrEQew==" + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.0.tgz", + "integrity": "sha512-TQ92mBOW0l3LeMeyLV6mzy/kWr8lkd/hp3mTg7wYK7zJhuBStmGMBG0BdeDZS/dZx1IukaX6Bk11zcln25o1Aw==" }, "node_modules/picomatch": { "version": "2.3.1", @@ -7162,9 +7217,9 @@ } }, "node_modules/postcss": { - "version": "8.4.41", - "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.41.tgz", - "integrity": "sha512-TesUflQ0WKZqAvg52PWL6kHgLKP6xB6heTOdoYM0Wt2UHyxNa4K25EZZMgKns3BH1RLVbZCREPpLY0rhnNoHVQ==", + "version": "8.4.47", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.47.tgz", + "integrity": "sha512-56rxCq7G/XfB4EkXq9Egn5GCqugWvDFjafDOThIdMBsI15iqPqR5r15TfSr1YPYeEI19YeaXMCbY6u88Y76GLQ==", "funding": [ { "type": "opencollective", @@ -7181,8 +7236,8 @@ ], "dependencies": { "nanoid": "^3.3.7", - "picocolors": "^1.0.1", - "source-map-js": "^1.2.0" + "picocolors": "^1.1.0", + "source-map-js": "^1.2.1" 
}, "engines": { "node": "^10 || ^12 || >=14" @@ -8195,9 +8250,9 @@ "integrity": "sha512-FJF5jgdfvoKn1MAKSdGs33bIqLi3LmsgVTliuX6iITj834F+JRQZN90Z93yql8h0K2t0RwDPBmxwlbZfDcxNZA==" }, "node_modules/source-map-js": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.0.tgz", - "integrity": "sha512-itJW8lvSA0TXEphiRoawsCksnlf8SyvmFzIhltqAHluXd88pkCd+cXJVHTDwdCr0IzwptSm035IHQktUu1QUMg==", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", + "integrity": "sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==", "engines": { "node": ">=0.10.0" } @@ -9112,7 +9167,7 @@ "version": "5.26.5", "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==", - "devOptional": true + "optional": true }, "node_modules/unist-util-stringify-position": { "version": "3.0.3", @@ -9329,13 +9384,13 @@ } }, "node_modules/vite": { - "version": "5.4.0", - "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.0.tgz", - "integrity": "sha512-5xokfMX0PIiwCMCMb9ZJcMyh5wbBun0zUzKib+L65vAZ8GY9ePZMXxFrHbr/Kyll2+LSCY7xtERPpxkBDKngwg==", + "version": "5.4.6", + "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.6.tgz", + "integrity": "sha512-IeL5f8OO5nylsgzd9tq4qD2QqI0k2CQLGrWD0rCN0EQJZpBK5vJAx0I+GDkMOXxQX/OfFHMuLIx6ddAxGX/k+Q==", "dependencies": { "esbuild": "^0.21.3", - "postcss": "^8.4.40", - "rollup": "^4.13.0" + "postcss": "^8.4.43", + "rollup": "^4.20.0" }, "bin": { "vite": "bin/vite.js" diff --git a/package.json b/package.json index 08c101384..371507789 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.3.21", + "version": "0.3.22", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", @@ -25,7 +25,6 @@ "@sveltejs/kit": "^2.5.20", "@sveltejs/vite-plugin-svelte": "^3.1.1", "@tailwindcss/typography": "^0.5.13", - "@types/bun": "latest", "@typescript-eslint/eslint-plugin": "^6.17.0", "@typescript-eslint/parser": "^6.17.0", "autoprefixer": "^10.4.16", @@ -54,6 +53,7 @@ "@codemirror/theme-one-dark": "^6.1.2", "@pyscript/core": "^0.4.32", "@sveltejs/adapter-node": "^2.0.0", + "@xyflow/svelte": "^0.1.19", "async": "^3.2.5", "bits-ui": "^0.19.7", "codemirror": "^6.0.1", diff --git a/pyproject.toml b/pyproject.toml index 057ef1475..b2558e4d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,12 @@ dependencies = [ "fake-useragent==1.5.1", "chromadb==0.5.5", + "pymilvus==2.4.6", + "sentence-transformers==3.0.1", + "colbert-ai==0.2.21", + "einops==0.8.0", + "pypdf==4.3.1", "docx2txt==0.8", "python-pptx==1.0.0", diff --git a/src/app.css b/src/app.css index a421d90ae..65103b55a 100644 --- a/src/app.css +++ b/src/app.css @@ -50,21 +50,6 @@ iframe { @apply rounded-lg; } -ol > li { - counter-increment: list-number; - display: block; - margin-bottom: 0; - margin-top: 0; - min-height: 28px; -} - -.prose ol > li::before { - content: counters(list-number, '.') '.'; - padding-right: 0.5rem; - color: var(--tw-prose-counters); - font-weight: 400; -} - li p { display: inline; } @@ -171,3 +156,20 @@ input[type='number'] { font-weight: 600; @apply rounded-md dark:bg-gray-800 bg-gray-100 mx-0.5; } + +.svelte-flow { + background-color: transparent !important; +} + +.svelte-flow__edge > path { + stroke-width: 0.5; +} + +.svelte-flow__edge.animated > path { + stroke-width: 2; + @apply 
stroke-gray-600 dark:stroke-gray-500; +} + +.bg-gray-950-90 { + background-color: rgba(var(--color-gray-950, #0d0d0d), 0.9); +} diff --git a/src/app.html b/src/app.html index 59fd7c5ed..d7f4513e7 100644 --- a/src/app.html +++ b/src/app.html @@ -4,7 +4,11 @@ - + + // On page load or when changing themes, best to add inline in `head` to avoid FOUC (() => { + const metaThemeColorTag = document.querySelector('meta[name="theme-color"]'); + const prefersDarkTheme = window.matchMedia('(prefers-color-scheme: dark)').matches; + if (!localStorage?.theme) { localStorage.theme = 'system'; } - if (localStorage?.theme && localStorage?.theme.includes('oled')) { + if (localStorage.theme === 'system') { + document.documentElement.classList.add(prefersDarkTheme ? 'dark' : 'light'); + metaThemeColorTag.setAttribute('content', prefersDarkTheme ? '#171717' : '#ffffff'); + } else if (localStorage.theme === 'oled-dark') { document.documentElement.style.setProperty('--color-gray-800', '#101010'); document.documentElement.style.setProperty('--color-gray-850', '#050505'); document.documentElement.style.setProperty('--color-gray-900', '#000000'); document.documentElement.style.setProperty('--color-gray-950', '#000000'); document.documentElement.classList.add('dark'); - } else if ( - localStorage.theme === 'light' || - (!('theme' in localStorage) && window.matchMedia('(prefers-color-scheme: light)').matches) - ) { + metaThemeColorTag.setAttribute('content', '#000000'); + } else if (localStorage.theme === 'light') { document.documentElement.classList.add('light'); - } else if (localStorage.theme && localStorage.theme !== 'system') { - localStorage.theme.split(' ').forEach((e) => { - document.documentElement.classList.add(e); - }); - } else if (localStorage.theme && localStorage.theme === 'system') { - systemTheme = window.matchMedia('(prefers-color-scheme: dark)').matches; - document.documentElement.classList.add(systemTheme ? 
'dark' : 'light'); - } else if (localStorage.theme && localStorage.theme === 'her') { + metaThemeColorTag.setAttribute('content', '#ffffff'); + } else if (localStorage.theme === 'her') { document.documentElement.classList.add('dark'); document.documentElement.classList.add('her'); + metaThemeColorTag.setAttribute('content', '#983724'); } else { document.documentElement.classList.add('dark'); + metaThemeColorTag.setAttribute('content', '#171717'); } window.matchMedia('(prefers-color-scheme: dark)').addListener((e) => { @@ -57,9 +61,11 @@ if (e.matches) { document.documentElement.classList.add('dark'); document.documentElement.classList.remove('light'); + metaThemeColorTag.setAttribute('content', '#171717'); } else { document.documentElement.classList.add('light'); document.documentElement.classList.remove('dark'); + metaThemeColorTag.setAttribute('content', '#ffffff'); } } }); diff --git a/src/lib/components/admin/Settings/Audio.svelte b/src/lib/components/admin/Settings/Audio.svelte index 1c114c9dd..040bc5e1a 100644 --- a/src/lib/components/admin/Settings/Audio.svelte +++ b/src/lib/components/admin/Settings/Audio.svelte @@ -31,6 +31,8 @@ let TTS_MODEL = ''; let TTS_VOICE = ''; let TTS_SPLIT_ON: TTS_RESPONSE_SPLIT = TTS_RESPONSE_SPLIT.PUNCTUATION; + let TTS_AZURE_SPEECH_REGION = ''; + let TTS_AZURE_SPEECH_OUTPUT_FORMAT = ''; let STT_OPENAI_API_BASE_URL = ''; let STT_OPENAI_API_KEY = ''; @@ -87,7 +89,9 @@ ENGINE: TTS_ENGINE, MODEL: TTS_MODEL, VOICE: TTS_VOICE, - SPLIT_ON: TTS_SPLIT_ON + SPLIT_ON: TTS_SPLIT_ON, + AZURE_SPEECH_REGION: TTS_AZURE_SPEECH_REGION, + AZURE_SPEECH_OUTPUT_FORMAT: TTS_AZURE_SPEECH_OUTPUT_FORMAT }, stt: { OPENAI_API_BASE_URL: STT_OPENAI_API_BASE_URL, @@ -120,6 +124,9 @@ TTS_SPLIT_ON = res.tts.SPLIT_ON || TTS_RESPONSE_SPLIT.PUNCTUATION; + TTS_AZURE_SPEECH_OUTPUT_FORMAT = res.tts.AZURE_SPEECH_OUTPUT_FORMAT; + TTS_AZURE_SPEECH_REGION = res.tts.AZURE_SPEECH_REGION; + STT_OPENAI_API_BASE_URL = res.stt.OPENAI_API_BASE_URL; STT_OPENAI_API_KEY = res.stt.OPENAI_API_KEY; @@ -224,6 +231,7 @@ + @@ -252,6 +260,23 @@ /> + {:else if TTS_ENGINE === 'azure'} +
<div> + <input type="password" placeholder={$i18n.t('API Key')} bind:value={TTS_API_KEY} required /> + <input placeholder={$i18n.t('Azure Region')} bind:value={TTS_AZURE_SPEECH_REGION} required /> + </div> {/if}
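
Aside: the region and output-format values bound in these settings feed Azure's REST text-to-speech endpoint. Below is a minimal Python sketch of that call, assuming the shape of Azure's public Speech API; the `azure_tts` helper and its defaults (which mirror the `eastus` and `audio-24khz-160kbitrate-mono-mp3` defaults added to `config.py` above) are illustrative, not code from this patch.

```python
# Minimal sketch of the Azure Speech TTS call that the region and
# output-format settings configure. Assumes Azure's public REST API;
# `azure_tts` is an illustrative helper, not code from this patch.
import requests


def azure_tts(
    text: str,
    api_key: str,
    region: str = "eastus",
    voice: str = "en-US-JennyNeural",
    output_format: str = "audio-24khz-160kbitrate-mono-mp3",
) -> bytes:
    url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
    headers = {
        "Ocp-Apim-Subscription-Key": api_key,
        "Content-Type": "application/ssml+xml",
        "X-Microsoft-OutputFormat": output_format,
    }
    # The endpoint expects SSML rather than plain text.
    ssml = (
        "<speak version='1.0' xml:lang='en-US'>"
        f"<voice name='{voice}'>{text}</voice>"
        "</speak>"
    )
    response = requests.post(url, headers=headers, data=ssml.encode("utf-8"))
    response.raise_for_status()
    return response.content  # audio bytes in the requested output format
```
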
@@ -359,6 +384,49 @@ + {:else if TTS_ENGINE === 'azure'} + <div class="flex gap-2"> + <div class="w-full"> + <div>{$i18n.t('TTS Voice')}</div> + <input list="voice-list" bind:value={TTS_VOICE} /> + <datalist id="voice-list"> + {#each voices as voice} + <option value={voice.id}>{voice.name}</option> + {/each} + </datalist> + </div> + <div class="w-full"> + <div> + {$i18n.t('Output format')} + <small>{$i18n.t('Available list')}</small> + </div> + <input bind:value={TTS_AZURE_SPEECH_OUTPUT_FORMAT} /> + </div> + </div> {/if}
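
Returning to the headline change in this patch: every Chroma-specific call in the memories router was replaced with a `VECTOR_DB_CLIENT`, selected via the new `VECTOR_DB` setting (`chroma` by default, with `milvus` newly supported). The call sites imply a small backend-agnostic interface; the `Protocol` below is a sketch inferred from this diff alone, and the actual connector in `open_webui.apps.rag.vector.connector` may expose more than this.

```python
# The interface implied by the VECTOR_DB_CLIENT call sites in
# routers/memories.py above. Method names and argument shapes are read
# off this diff; the real connector may define additional methods.
from typing import Any, Optional, Protocol


class VectorDBClient(Protocol):
    def upsert(self, collection_name: str, items: list[dict[str, Any]]) -> None:
        """Insert or update items shaped {id, text, vector, metadata}."""
        ...

    def search(
        self, collection_name: str, vectors: list[list[float]], limit: int
    ) -> Optional[Any]:
        """Return up to `limit` nearest neighbours for each query vector."""
        ...

    def delete(self, collection_name: str, ids: list[str]) -> None:
        ...

    def delete_collection(self, collection_name: str) -> None:
        ...
```
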
diff --git a/src/lib/components/admin/Settings/Connections.svelte b/src/lib/components/admin/Settings/Connections.svelte index fe71e4816..97a760110 100644 --- a/src/lib/components/admin/Settings/Connections.svelte +++ b/src/lib/components/admin/Settings/Connections.svelte @@ -150,18 +150,20 @@ })() ]); - OPENAI_API_BASE_URLS.forEach(async (url, idx) => { - const res = await getOpenAIModels(localStorage.token, idx); - if (res.pipelines) { - pipelineUrls[url] = true; - } - }); - const ollamaConfig = await getOllamaConfig(localStorage.token); const openaiConfig = await getOpenAIConfig(localStorage.token); ENABLE_OPENAI_API = openaiConfig.ENABLE_OPENAI_API; ENABLE_OLLAMA_API = ollamaConfig.ENABLE_OLLAMA_API; + + if (ENABLE_OPENAI_API) { + OPENAI_API_BASE_URLS.forEach(async (url, idx) => { + const res = await getOpenAIModels(localStorage.token, idx); + if (res.pipelines) { + pipelineUrls[url] = true; + } + }); + } } }); diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index 7f935a08a..e06edce9d 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -732,11 +732,17 @@
{$i18n.t('RAG Template')}
-
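
Finally, the `security_headers` module added earlier is self-contained and can be exercised on its own. A quick sketch, assuming only the module as added in this diff; values that fail a setter's validation fall back to that setter's safe default.

```python
# Exercises set_security_headers() from the module added above.
# Env var names come from its header_setters mapping; an invalid
# Cache-Control value falls back to the "no-store, max-age=0" default.
import os

from open_webui.utils.security_headers import set_security_headers

os.environ["HSTS"] = "max-age=63072000;includeSubDomains;preload"
os.environ["XFRAME_OPTIONS"] = "SAMEORIGIN"
os.environ["CACHE_CONTROL"] = "not-a-real-directive"

print(set_security_headers())
# {'Cache-Control': 'no-store, max-age=0',
#  'Strict-Transport-Security': 'max-age=63072000;includeSubDomains;preload',
#  'X-Frame-Options': 'SAMEORIGIN'}
```
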