diff --git a/CHANGELOG.md b/CHANGELOG.md
index 19686d028..46c1c644e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,34 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [0.3.22] - 2024-09-19
+
+### Added
+
+- **⭐ Chat Overview**: Introducing a node-based interactive messages diagram for improved visualization of conversation flows.
+- **🔗 Multiple Vector DB Support**: Now supports multiple vector databases, including the newly added Milvus support. Community contributions for additional database support are highly encouraged!
+- **📡 Experimental Non-Stream Chat Completion**: Allows the use of OpenAI o1 models, which do not support streaming, ensuring more versatile model deployment (see the request sketch after this CHANGELOG diff).
+- **🔍 Experimental Colbert-AI Reranker Integration**: Added support for "jinaai/jina-colbert-v2" as a reranker, enhancing search relevance and accuracy. Note: it may not function as expected on low-spec computers.
+- **🕸️ ENABLE_WEBSOCKET_SUPPORT**: Added environment variable for instances to ignore websocket upgrades, stabilizing connections on platforms with websocket issues.
+- **🔊 Azure Speech Service Integration**: Added support for Azure Speech services for Text-to-Speech (TTS).
+- **🎚️ Customizable Playback Speed**: Playback speed control is now available in Call mode settings, allowing users to adjust audio playback speed to their preferences.
+- **🧠 Enhanced Error Messaging**: System now displays helpful error messages directly to users during chat completion issues.
+- **📂 Save Model as Transparent PNG**: Model profile images are now saved as PNGs, supporting transparency and improving visual integration.
+- **📱 iPhone Compatibility Adjustments**: Added padding to accommodate the iPhone navigation bar, improving UI display on these devices.
+- **🔗 Secure Response Headers**: Implemented security response headers, bolstering web application security.
+- **🔧 Enhanced AUTOMATIC1111 Settings**: Users can now configure 'CFG Scale', 'Sampler', and 'Scheduler' parameters directly in the admin settings, enhancing workflow flexibility without source code modifications.
+- **🌍 i18n Updates**: Enhanced translations for Chinese, Ukrainian, Russian, and French, fostering a better localized experience.
+
+### Fixed
+
+- **🛠️ Chat Message Deletion**: Resolved issues with chat message deletion, ensuring a smoother user interaction and system stability.
+- **🔢 Ordered List Numbering**: Fixed the incorrect ordering in lists.
+
+### Changed
+
+- **🎨 Transparent Icon Handling**: Allowed model icons to be displayed on transparent backgrounds, improving UI aesthetics.
+- **📝 Improved RAG Template**: Enhanced Retrieval-Augmented Generation template, optimizing context handling and error checking for more precise operation.
+
## [0.3.21] - 2024-09-08
### Added
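The non-stream path referenced in the 0.3.22 notes above is wired through `post_streaming_url(..., stream=form_data.stream, ...)` in the Ollama diff below. A minimal request sketch, assuming an OpenAI-compatible chat endpoint; host, path, token, and model name are placeholders, not values taken from this diff:

```python
import requests

# Non-streaming chat completion: one JSON body back instead of SSE chunks,
# which is what models like OpenAI o1 require.
response = requests.post(
    "http://localhost:8080/api/chat/completions",  # placeholder endpoint
    headers={"Authorization": "Bearer <api-key>"},
    json={
        "model": "o1-preview",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,  # the key difference from the default streaming call
    },
    timeout=120,
)
print(response.json())
```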
diff --git a/Dockerfile b/Dockerfile
index 8078bf0ea..c944f54e6 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -74,6 +74,10 @@ ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
## Hugging Face download cache ##
ENV HF_HOME="/app/backend/data/cache/embedding/models"
+
+## Torch Extensions ##
+# ENV TORCH_EXTENSIONS_DIR="/.cache/torch_extensions"
+
#### Other models ##########################################################
WORKDIR /app/backend
@@ -96,7 +100,7 @@ RUN chown -R $UID:$GID /app $HOME
RUN if [ "$USE_OLLAMA" = "true" ]; then \
apt-get update && \
# Install pandoc and netcat
- apt-get install -y --no-install-recommends pandoc netcat-openbsd curl && \
+ apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \
apt-get install -y --no-install-recommends gcc python3-dev && \
# for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
@@ -109,7 +113,7 @@ RUN if [ "$USE_OLLAMA" = "true" ]; then \
else \
apt-get update && \
# Install pandoc, netcat and gcc
- apt-get install -y --no-install-recommends pandoc gcc netcat-openbsd curl jq && \
+ apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \
apt-get install -y --no-install-recommends gcc python3-dev && \
# for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
@@ -157,5 +161,6 @@ USER $UID:$GID
ARG BUILD_HASH
ENV WEBUI_BUILD_VERSION=${BUILD_HASH}
+ENV DOCKER true
CMD [ "bash", "start.sh"]
diff --git a/backend/open_webui/__init__.py b/backend/open_webui/__init__.py
index 30e83b198..167f0fb60 100644
--- a/backend/open_webui/__init__.py
+++ b/backend/open_webui/__init__.py
@@ -39,6 +39,19 @@ def serve(
"/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib",
]
)
+ try:
+ import torch
+
+ assert torch.cuda.is_available(), "CUDA not available"
+ typer.echo("CUDA seems to be working")
+ except Exception as e:
+ typer.echo(
+ "Error when testing CUDA but USE_CUDA_DOCKER is true. "
+ "Resetting USE_CUDA_DOCKER to false and removing "
+ f"LD_LIBRARY_PATH modifications: {e}"
+ )
+ os.environ["USE_CUDA_DOCKER"] = "false"
+ os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
import open_webui.main # we need set environment variables before importing main
uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*")
diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py
index 1fc44b28f..0eee533bd 100644
--- a/backend/open_webui/apps/audio/main.py
+++ b/backend/open_webui/apps/audio/main.py
@@ -19,16 +19,18 @@ from open_webui.config import (
AUDIO_TTS_OPENAI_API_KEY,
AUDIO_TTS_SPLIT_ON,
AUDIO_TTS_VOICE,
+ AUDIO_TTS_AZURE_SPEECH_REGION,
+ AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
CACHE_DIR,
CORS_ALLOW_ORIGIN,
- DEVICE_TYPE,
WHISPER_MODEL,
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
AppConfig,
)
+
from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import SRC_LOG_LEVELS
+from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
@@ -62,6 +64,9 @@ app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
+app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
+app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
+
# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
log.info(f"whisper_device_type: {whisper_device_type}")
@@ -78,6 +83,8 @@ class TTSConfigForm(BaseModel):
MODEL: str
VOICE: str
SPLIT_ON: str
+ AZURE_SPEECH_REGION: str
+ AZURE_SPEECH_OUTPUT_FORMAT: str
class STTConfigForm(BaseModel):
@@ -130,6 +137,8 @@ async def get_audio_config(user=Depends(get_admin_user)):
"MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
"SPLIT_ON": app.state.config.TTS_SPLIT_ON,
+ "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
+ "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
@@ -151,6 +160,10 @@ async def update_audio_config(
app.state.config.TTS_MODEL = form_data.tts.MODEL
app.state.config.TTS_VOICE = form_data.tts.VOICE
app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
+ app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
+ app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
+ form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
+ )
app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
@@ -166,6 +179,8 @@ async def update_audio_config(
"MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
"SPLIT_ON": app.state.config.TTS_SPLIT_ON,
+ "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
+ "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
@@ -301,6 +316,42 @@ async def speech(request: Request, user=Depends(get_verified_user)):
detail=error_detail,
)
+ elif app.state.config.TTS_ENGINE == "azure":
+ payload = None
+ try:
+ payload = json.loads(body.decode("utf-8"))
+ except Exception as e:
+ log.exception(e)
+ raise HTTPException(status_code=400, detail="Invalid JSON payload")
+
+ region = app.state.config.TTS_AZURE_SPEECH_REGION
+ language = app.state.config.TTS_VOICE
+ locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1])
+ output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
+ url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
+
+ headers = {
+ "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY,
+ "Content-Type": "application/ssml+xml",
+ "X-Microsoft-OutputFormat": output_format,
+ }
+
+        data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
+            <voice name="{language}">{payload["input"]}</voice>
+        </speak>"""
+
+ response = requests.post(url, headers=headers, data=data)
+
+ if response.status_code == 200:
+ with open(file_path, "wb") as f:
+ f.write(response.content)
+ return FileResponse(file_path)
+ else:
+ log.error(f"Error synthesizing speech - {response.reason}")
+ raise HTTPException(
+ status_code=500, detail=f"Error synthesizing speech - {response.reason}"
+ )
+
@app.post("/transcriptions")
def transcribe(
@@ -309,7 +360,7 @@ def transcribe(
):
log.info(f"file.content_type: {file.content_type}")
- if file.content_type not in ["audio/mpeg", "audio/wav"]:
+ if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
@@ -443,7 +494,7 @@ def get_available_models() -> list[dict]:
try:
response = requests.get(
- "https://api.elevenlabs.io/v1/models", headers=headers
+ "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5
)
response.raise_for_status()
models = response.json()
@@ -478,6 +529,21 @@ def get_available_voices() -> dict:
except Exception:
# Avoided @lru_cache with exception
pass
+ elif app.state.config.TTS_ENGINE == "azure":
+ try:
+ region = app.state.config.TTS_AZURE_SPEECH_REGION
+ url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
+ headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY}
+
+ response = requests.get(url, headers=headers)
+ response.raise_for_status()
+ voices = response.json()
+ for voice in voices:
+ ret[voice["ShortName"]] = (
+ f"{voice['DisplayName']} ({voice['ShortName']})"
+ )
+ except requests.RequestException as e:
+ log.error(f"Error fetching voices: {str(e)}")
return ret
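The Azure branch above boils down to one REST call against the regional TTS endpoint. A standalone sketch of the same request; region, voice, key, and output format are placeholders standing in for the new config values:

```python
import requests

region = "eastus"                                   # TTS_AZURE_SPEECH_REGION
voice = "en-US-JennyNeural"                         # TTS_VOICE
output_format = "audio-24khz-48kbitrate-mono-mp3"   # TTS_AZURE_SPEECH_OUTPUT_FORMAT

url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
headers = {
    "Ocp-Apim-Subscription-Key": "<azure-speech-key>",
    "Content-Type": "application/ssml+xml",
    "X-Microsoft-OutputFormat": output_format,
}
# Azure expects an SSML document, not plain text.
ssml = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="en-US">
    <voice name="{voice}">Hello from Azure Speech.</voice>
</speak>"""

response = requests.post(url, headers=headers, data=ssml, timeout=10)
response.raise_for_status()
with open("speech.mp3", "wb") as f:
    f.write(response.content)
```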
diff --git a/backend/open_webui/apps/images/main.py b/backend/open_webui/apps/images/main.py
index 17afd645c..1074e2cb0 100644
--- a/backend/open_webui/apps/images/main.py
+++ b/backend/open_webui/apps/images/main.py
@@ -17,6 +17,9 @@ from open_webui.apps.images.utils.comfyui import (
from open_webui.config import (
AUTOMATIC1111_API_AUTH,
AUTOMATIC1111_BASE_URL,
+ AUTOMATIC1111_CFG_SCALE,
+ AUTOMATIC1111_SAMPLER,
+ AUTOMATIC1111_SCHEDULER,
CACHE_DIR,
COMFYUI_BASE_URL,
COMFYUI_WORKFLOW,
@@ -65,6 +68,9 @@ app.state.config.MODEL = IMAGE_GENERATION_MODEL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
+app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
+app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
+app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
@@ -85,6 +91,9 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
"automatic1111": {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
+ "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
+ "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
+ "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
},
"comfyui": {
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
@@ -102,6 +111,9 @@ class OpenAIConfigForm(BaseModel):
class Automatic1111ConfigForm(BaseModel):
AUTOMATIC1111_BASE_URL: str
AUTOMATIC1111_API_AUTH: str
+ AUTOMATIC1111_CFG_SCALE: Optional[str]
+ AUTOMATIC1111_SAMPLER: Optional[str]
+ AUTOMATIC1111_SCHEDULER: Optional[str]
class ComfyUIConfigForm(BaseModel):
@@ -133,6 +145,22 @@ async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
form_data.automatic1111.AUTOMATIC1111_API_AUTH
)
+ app.state.config.AUTOMATIC1111_CFG_SCALE = (
+ float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
+ if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE != ""
+ else None
+ )
+ app.state.config.AUTOMATIC1111_SAMPLER = (
+ form_data.automatic1111.AUTOMATIC1111_SAMPLER
+ if form_data.automatic1111.AUTOMATIC1111_SAMPLER != ""
+ else None
+ )
+ app.state.config.AUTOMATIC1111_SCHEDULER = (
+ form_data.automatic1111.AUTOMATIC1111_SCHEDULER
+ if form_data.automatic1111.AUTOMATIC1111_SCHEDULER != ""
+ else None
+ )
+
app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES
@@ -147,6 +175,9 @@ async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
"automatic1111": {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
+ "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
+ "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
+ "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
},
"comfyui": {
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
@@ -524,6 +555,15 @@ async def image_generations(
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
+ if app.state.config.AUTOMATIC1111_CFG_SCALE:
+ data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE
+
+ if app.state.config.AUTOMATIC1111_SAMPLER:
+ data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER
+
+ if app.state.config.AUTOMATIC1111_SCHEDULER:
+ data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER
+
# Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread(
requests.post,
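The three new settings map one-to-one onto fields of the AUTOMATIC1111 `txt2img` payload, as the `data[...]` assignments above show. A client-side sketch of the resulting request; base URL and parameter values are placeholders:

```python
import requests

data = {
    "prompt": "a watercolor lighthouse",
    "width": 512,
    "height": 512,
    "cfg_scale": 7.0,           # AUTOMATIC1111_CFG_SCALE
    "sampler_name": "Euler a",  # AUTOMATIC1111_SAMPLER
    "scheduler": "Karras",      # AUTOMATIC1111_SCHEDULER
}
r = requests.post("http://localhost:7860/sdapi/v1/txt2img", json=data, timeout=300)
r.raise_for_status()
print(len(r.json()["images"]), "image(s) returned")  # base64-encoded PNGs
```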
diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py
index fe36010b7..6c639268e 100644
--- a/backend/open_webui/apps/ollama/main.py
+++ b/backend/open_webui/apps/ollama/main.py
@@ -545,6 +545,55 @@ class GenerateEmbeddingsForm(BaseModel):
@app.post("/api/embed")
@app.post("/api/embed/{url_idx}")
+async def generate_embeddings(
+ form_data: GenerateEmbeddingsForm,
+ url_idx: Optional[int] = None,
+ user=Depends(get_verified_user),
+):
+ if url_idx is None:
+ model = form_data.model
+
+ if ":" not in model:
+ model = f"{model}:latest"
+
+ if model in app.state.MODELS:
+ url_idx = random.choice(app.state.MODELS[model]["urls"])
+ else:
+ raise HTTPException(
+ status_code=400,
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+ )
+
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+ log.info(f"url: {url}")
+
+ r = requests.request(
+ method="POST",
+ url=f"{url}/api/embed",
+ headers={"Content-Type": "application/json"},
+ data=form_data.model_dump_json(exclude_none=True).encode(),
+ )
+ try:
+ r.raise_for_status()
+
+ return r.json()
+ except Exception as e:
+ log.exception(e)
+ error_detail = "Open WebUI: Server Connection Error"
+ if r is not None:
+ try:
+ res = r.json()
+ if "error" in res:
+ error_detail = f"Ollama: {res['error']}"
+ except Exception:
+ error_detail = f"Ollama: {e}"
+
+ raise HTTPException(
+ status_code=r.status_code if r else 500,
+ detail=error_detail,
+ )
+
+
@app.post("/api/embeddings")
@app.post("/api/embeddings/{url_idx}")
async def generate_embeddings(
@@ -571,7 +620,7 @@ async def generate_embeddings(
r = requests.request(
method="POST",
- url=f"{url}/api/embed",
+ url=f"{url}/api/embeddings",
headers={"Content-Type": "application/json"},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@@ -767,7 +816,10 @@ async def generate_chat_completion(
log.debug(payload)
return await post_streaming_url(
- f"{url}/api/chat", json.dumps(payload), content_type="application/x-ndjson"
+ f"{url}/api/chat",
+ json.dumps(payload),
+ stream=form_data.stream,
+ content_type="application/x-ndjson",
)
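The two embedding endpoints now proxy to distinct upstream Ollama routes: the newer batch-capable `/api/embed` and the legacy `/api/embeddings`. A sketch of the difference when calling Ollama directly, with host and model as placeholders (field names per Ollama's documented API):

```python
import requests

base = "http://localhost:11434"  # placeholder Ollama URL

# Newer /api/embed: "input" may be a string or a list; returns "embeddings".
r = requests.post(
    f"{base}/api/embed",
    json={"model": "nomic-embed-text", "input": ["hello", "world"]},
)
print(len(r.json()["embeddings"]))  # one vector per input

# Legacy /api/embeddings: single "prompt"; returns one "embedding".
r = requests.post(
    f"{base}/api/embeddings",
    json={"model": "nomic-embed-text", "prompt": "hello"},
)
print(len(r.json()["embedding"]))
```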
diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py
index 23ea1cc8c..9a27c46a3 100644
--- a/backend/open_webui/apps/openai/main.py
+++ b/backend/open_webui/apps/openai/main.py
@@ -423,6 +423,7 @@ async def generate_chat_completion(
r = None
session = None
streaming = False
+ response = None
try:
session = aiohttp.ClientSession(
@@ -435,8 +436,6 @@ async def generate_chat_completion(
headers=headers,
)
- r.raise_for_status()
-
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True
@@ -449,19 +448,23 @@ async def generate_chat_completion(
),
)
else:
- response_data = await r.json()
- return response_data
+ try:
+ response = await r.json()
+ except Exception as e:
+ log.error(e)
+ response = await r.text()
+
+ r.raise_for_status()
+ return response
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
- if r is not None:
- try:
- res = await r.json()
- print(res)
- if "error" in res:
- error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
- except Exception:
- error_detail = f"External: {e}"
+ if isinstance(response, dict):
+ if "error" in response:
+ error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
+ elif isinstance(response, str):
+ error_detail = response
+
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:
if not streaming and session:
diff --git a/backend/open_webui/apps/rag/main.py b/backend/open_webui/apps/rag/main.py
index 6c064fe81..981a6fe5b 100644
--- a/backend/open_webui/apps/rag/main.py
+++ b/backend/open_webui/apps/rag/main.py
@@ -10,13 +10,21 @@ from datetime import datetime
from pathlib import Path
from typing import Iterator, Optional, Sequence, Union
+
+import numpy as np
+import torch
import requests
import validators
+
+from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+
+from open_webui.apps.rag.search.main import SearchResult
from open_webui.apps.rag.search.brave import search_brave
from open_webui.apps.rag.search.duckduckgo import search_duckduckgo
from open_webui.apps.rag.search.google_pse import search_google_pse
from open_webui.apps.rag.search.jina_search import search_jina
-from open_webui.apps.rag.search.main import SearchResult
from open_webui.apps.rag.search.searchapi import search_searchapi
from open_webui.apps.rag.search.searxng import search_searxng
from open_webui.apps.rag.search.serper import search_serper
@@ -33,15 +41,12 @@ from open_webui.apps.rag.utils import (
)
from open_webui.apps.webui.models.documents import DocumentForm, Documents
from open_webui.apps.webui.models.files import Files
-from chromadb.utils.batch_utils import create_batches
from open_webui.config import (
BRAVE_SEARCH_API_KEY,
- CHROMA_CLIENT,
CHUNK_OVERLAP,
CHUNK_SIZE,
CONTENT_EXTRACTION_ENGINE,
CORS_ALLOW_ORIGIN,
- DEVICE_TYPE,
DOCS_DIR,
ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_LOCAL_WEB_FETCH,
@@ -64,6 +69,7 @@ from open_webui.config import (
RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+ DEFAULT_RAG_TEMPLATE,
RAG_TEMPLATE,
RAG_TOP_K,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
@@ -84,9 +90,16 @@ from open_webui.config import (
AppConfig,
)
from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import SRC_LOG_LEVELS
-from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
-from fastapi.middleware.cors import CORSMiddleware
+from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER
+from open_webui.utils.misc import (
+ calculate_sha256,
+ calculate_sha256_string,
+ extract_folders_after_data_docs,
+ sanitize_filename,
+)
+from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
BSHTMLLoader,
@@ -105,14 +118,8 @@ from langchain_community.document_loaders import (
YoutubeLoader,
)
from langchain_core.documents import Document
-from pydantic import BaseModel
-from open_webui.utils.misc import (
- calculate_sha256,
- calculate_sha256_string,
- extract_folders_after_data_docs,
- sanitize_filename,
-)
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from colbert.infra import ColBERTConfig
+from colbert.modeling.checkpoint import Checkpoint
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -143,13 +150,11 @@ app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SI
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
-
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
-
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None
@@ -175,13 +180,13 @@ app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_
def update_embedding_model(
embedding_model: str,
- update_model: bool = False,
+ auto_update: bool = False,
):
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
import sentence_transformers
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
- get_model_path(embedding_model, update_model),
+ get_model_path(embedding_model, auto_update),
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)
@@ -191,16 +196,108 @@ def update_embedding_model(
def update_reranking_model(
reranking_model: str,
- update_model: bool = False,
+ auto_update: bool = False,
):
if reranking_model:
- import sentence_transformers
+ if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
- app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
- get_model_path(reranking_model, update_model),
- device=DEVICE_TYPE,
- trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
- )
+ class ColBERT:
+ def __init__(self, name) -> None:
+ print("ColBERT: Loading model", name)
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ if DOCKER:
+ # This is a workaround for the issue with the docker container
+ # where the torch extension is not loaded properly
+ # and the following error is thrown:
+ # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
+
+ lock_file = "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
+ if os.path.exists(lock_file):
+ os.remove(lock_file)
+
+ self.ckpt = Checkpoint(
+ name,
+ colbert_config=ColBERTConfig(model_name=name),
+ ).to(self.device)
+ pass
+
+ def calculate_similarity_scores(
+ self, query_embeddings, document_embeddings
+ ):
+
+ query_embeddings = query_embeddings.to(self.device)
+ document_embeddings = document_embeddings.to(self.device)
+
+ # Validate dimensions to ensure compatibility
+ if query_embeddings.dim() != 3:
+ raise ValueError(
+ f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
+ )
+ if document_embeddings.dim() != 3:
+ raise ValueError(
+ f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
+ )
+ if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
+ raise ValueError(
+ "There should be either one query or queries equal to the number of documents."
+ )
+
+ # Transpose the query embeddings to align for matrix multiplication
+ transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
+ # Compute similarity scores using batch matrix multiplication
+ computed_scores = torch.matmul(
+ document_embeddings, transposed_query_embeddings
+ )
+ # Apply max pooling to extract the highest semantic similarity across each document's sequence
+ maximum_scores = torch.max(computed_scores, dim=1).values
+
+ # Sum up the maximum scores across features to get the overall document relevance scores
+ final_scores = maximum_scores.sum(dim=1)
+
+ normalized_scores = torch.softmax(final_scores, dim=0)
+
+ return normalized_scores.detach().cpu().numpy().astype(np.float32)
+
+ def predict(self, sentences):
+
+ query = sentences[0][0]
+ docs = [i[1] for i in sentences]
+
+ # Embedding the documents
+ embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
+ # Embedding the queries
+ embedded_queries = self.ckpt.queryFromText([query], bsize=32)
+ embedded_query = embedded_queries[0]
+
+ # Calculate retrieval scores for the query against all documents
+ scores = self.calculate_similarity_scores(
+ embedded_query.unsqueeze(0), embedded_docs
+ )
+
+ return scores
+
+ try:
+ app.state.sentence_transformer_rf = ColBERT(
+ get_model_path(reranking_model, auto_update)
+ )
+ except Exception as e:
+ log.error(f"ColBERT: {e}")
+ app.state.sentence_transformer_rf = None
+ app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
+ else:
+ import sentence_transformers
+
+ try:
+ app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
+ get_model_path(reranking_model, auto_update),
+ device=DEVICE_TYPE,
+ trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+ )
+ except:
+ log.error("CrossEncoder error")
+ app.state.sentence_transformer_rf = None
+ app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
else:
app.state.sentence_transformer_rf = None
@@ -593,7 +690,7 @@ async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
app.state.config.RAG_TEMPLATE = (
- form_data.template if form_data.template else RAG_TEMPLATE
+ form_data.template if form_data.template != "" else DEFAULT_RAG_TEMPLATE
)
app.state.config.TOP_K = form_data.k if form_data.k else 4
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
@@ -998,14 +1095,11 @@ def store_docs_in_vector_db(
try:
if overwrite:
- for collection in CHROMA_CLIENT.list_collections():
- if collection_name == collection.name:
- log.info(f"deleting existing collection {collection_name}")
- CHROMA_CLIENT.delete_collection(name=collection_name)
+ if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
+ log.info(f"deleting existing collection {collection_name}")
+ VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
- collection = CHROMA_CLIENT.create_collection(name=collection_name)
-
- embedding_func = get_embedding_function(
+ embedding_function = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
@@ -1014,17 +1108,18 @@ def store_docs_in_vector_db(
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
- embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
- embeddings = embedding_func(embedding_texts)
-
- for batch in create_batches(
- api=CHROMA_CLIENT,
- ids=[str(uuid.uuid4()) for _ in texts],
- metadatas=metadatas,
- embeddings=embeddings,
- documents=texts,
- ):
- collection.add(*batch)
+ VECTOR_DB_CLIENT.insert(
+ collection_name=collection_name,
+ items=[
+ {
+ "id": str(uuid.uuid4()),
+ "text": text,
+ "vector": embedding_function(text.replace("\n", " ")),
+ "metadata": metadatas[idx],
+ }
+ for idx, text in enumerate(texts)
+ ],
+ )
return True
except Exception as e:
@@ -1158,7 +1253,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
elif (
file_content_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
- or file_ext in ["doc", "docx"]
+ or file_ext == "docx"
):
loader = Docx2txtLoader(file_path)
elif file_content_type in [
@@ -1396,7 +1491,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
@app.post("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
- CHROMA_CLIENT.reset()
+ VECTOR_DB_CLIENT.reset()
@app.post("/reset/uploads")
@@ -1437,7 +1532,7 @@ def reset(user=Depends(get_admin_user)) -> bool:
log.error("Failed to delete %s. Reason: %s" % (file_path, e))
try:
- CHROMA_CLIENT.reset()
+ VECTOR_DB_CLIENT.reset()
except Exception as e:
log.exception(e)
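The MaxSim scoring in `calculate_similarity_scores` is easiest to sanity-check on toy tensors; a minimal sketch of the same operations outside the class:

```python
import torch

# One query (2 token embeddings) against 2 documents (3 token embeddings each),
# all 4-dimensional — the shapes calculate_similarity_scores expects.
query = torch.randn(1, 2, 4)  # (num_queries, query_tokens, dim)
docs = torch.randn(2, 3, 4)   # (num_docs, doc_tokens, dim)

scores = torch.matmul(docs, query.permute(0, 2, 1))  # (2, 3, 2) token-level sims
max_sim = torch.max(scores, dim=1).values            # MaxSim over document tokens
relevance = max_sim.sum(dim=1)                       # sum over query tokens
print(torch.softmax(relevance, dim=0))               # normalized, as in the diff
```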
diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py
index 2bf8a02e4..73ccfad38 100644
--- a/backend/open_webui/apps/rag/utils.py
+++ b/backend/open_webui/apps/rag/utils.py
@@ -1,24 +1,68 @@
import logging
import os
+import uuid
from typing import Optional, Union
import requests
-from open_webui.apps.ollama.main import (
- GenerateEmbeddingsForm,
- generate_ollama_embeddings,
-)
-from open_webui.config import CHROMA_CLIENT
-from open_webui.env import SRC_LOG_LEVELS
+
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
+
+
+from open_webui.apps.ollama.main import (
+ GenerateEmbeddingsForm,
+ generate_ollama_embeddings,
+)
+from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
+from open_webui.env import SRC_LOG_LEVELS
+
+
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
+from typing import Any
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.retrievers import BaseRetriever
+
+
+class VectorSearchRetriever(BaseRetriever):
+ collection_name: Any
+ embedding_function: Any
+ top_k: int
+
+ def _get_relevant_documents(
+ self,
+ query: str,
+ *,
+ run_manager: CallbackManagerForRetrieverRun,
+ ) -> list[Document]:
+ result = VECTOR_DB_CLIENT.search(
+ collection_name=self.collection_name,
+ vectors=[self.embedding_function(query)],
+ limit=self.top_k,
+ )
+
+ ids = result.ids[0]
+ metadatas = result.metadatas[0]
+ documents = result.documents[0]
+
+ results = []
+ for idx in range(len(ids)):
+ results.append(
+ Document(
+ metadata=metadatas[idx],
+ page_content=documents[idx],
+ )
+ )
+ return results
+
+
def query_doc(
collection_name: str,
query: str,
@@ -26,17 +70,18 @@ def query_doc(
k: int,
):
try:
- collection = CHROMA_CLIENT.get_collection(name=collection_name)
- query_embeddings = embedding_function(query)
-
- result = collection.query(
- query_embeddings=[query_embeddings],
- n_results=k,
+ result = VECTOR_DB_CLIENT.search(
+ collection_name=collection_name,
+ vectors=[embedding_function(query)],
+ limit=k,
)
+ print("result", result)
+
log.info(f"query_doc:result {result}")
return result
except Exception as e:
+ print(e)
raise e
@@ -47,27 +92,25 @@ def query_doc_with_hybrid_search(
k: int,
reranking_function,
r: float,
-):
+) -> dict:
try:
- collection = CHROMA_CLIENT.get_collection(name=collection_name)
- documents = collection.get() # get all documents
+ result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
bm25_retriever = BM25Retriever.from_texts(
- texts=documents.get("documents"),
- metadatas=documents.get("metadatas"),
+ texts=result.documents[0],
+ metadatas=result.metadatas[0],
)
bm25_retriever.k = k
- chroma_retriever = ChromaRetriever(
- collection=collection,
+ vector_search_retriever = VectorSearchRetriever(
+ collection_name=collection_name,
embedding_function=embedding_function,
- top_n=k,
+ top_k=k,
)
ensemble_retriever = EnsembleRetriever(
- retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
+ retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
)
-
compressor = RerankCompressor(
embedding_function=embedding_function,
top_n=k,
@@ -92,7 +135,9 @@ def query_doc_with_hybrid_search(
raise e
-def merge_and_sort_query_results(query_results, k, reverse=False):
+def merge_and_sort_query_results(
+ query_results: list[dict], k: int, reverse: bool = False
+) -> list[dict]:
# Initialize lists to store combined data
combined_distances = []
combined_documents = []
@@ -138,7 +183,7 @@ def query_collection(
query: str,
embedding_function,
k: int,
-):
+) -> dict:
results = []
for collection_name in collection_names:
if collection_name:
@@ -149,9 +194,9 @@ def query_collection(
k=k,
embedding_function=embedding_function,
)
- results.append(result)
- except Exception:
- pass
+ results.append(result.model_dump())
+ except Exception as e:
+ log.exception(f"Error when querying the collection: {e}")
else:
pass
@@ -165,8 +210,9 @@ def query_collection_with_hybrid_search(
k: int,
reranking_function,
r: float,
-):
+) -> dict:
results = []
+ error = False
for collection_name in collection_names:
try:
result = query_doc_with_hybrid_search(
@@ -178,14 +224,39 @@ def query_collection_with_hybrid_search(
r=r,
)
results.append(result)
- except Exception:
- pass
+ except Exception as e:
+ log.exception(
+ "Error when querying the collection with " f"hybrid_search: {e}"
+ )
+ error = True
+
+ if error:
+ raise Exception(
+            "Hybrid search failed for all collections. Using non-hybrid search as fallback."
+ )
+
return merge_and_sort_query_results(results, k=k, reverse=True)
def rag_template(template: str, context: str, query: str):
- template = template.replace("[context]", context)
- template = template.replace("[query]", query)
+ count = template.count("[context]")
+ assert "[context]" in template, "RAG template does not contain '[context]'"
+
+ if "" in context and "" in context:
+ log.debug(
+ "WARNING: Potential prompt injection attack: the RAG "
+ "context contains '' and ''. This might be "
+ "nothing, or the user might be trying to hack something."
+ )
+
+ if "[query]" in context:
+ query_placeholder = f"[query-{str(uuid.uuid4())}]"
+ template = template.replace("[query]", query_placeholder)
+ template = template.replace("[context]", context)
+ template = template.replace(query_placeholder, query)
+ else:
+ template = template.replace("[context]", context)
+ template = template.replace("[query]", query)
return template
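The placeholder indirection above guards a subtle ordering bug: if the retrieved context itself contains the literal string `[query]`, substituting `[query]` after `[context]` would rewrite the occurrence inside the context too. A minimal illustration of the naive failure:

```python
template = "Context: [context]\nQuestion: [query]"
context = "This document explains the [query] placeholder."
query = "What is RAG?"

# Naive order — [context] first, then [query] — clobbers the context as well:
naive = template.replace("[context]", context).replace("[query]", query)
print(naive)
# Context: This document explains the What is RAG? placeholder.
# Question: What is RAG?
```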
@@ -262,19 +333,27 @@ def get_rag_context(
continue
try:
+ context = None
if file["type"] == "text":
context = file["content"]
else:
if hybrid_search:
- context = query_collection_with_hybrid_search(
- collection_names=collection_names,
- query=query,
- embedding_function=embedding_function,
- k=k,
- reranking_function=reranking_function,
- r=r,
- )
- else:
+ try:
+ context = query_collection_with_hybrid_search(
+ collection_names=collection_names,
+ query=query,
+ embedding_function=embedding_function,
+ k=k,
+ reranking_function=reranking_function,
+ r=r,
+ )
+ except Exception as e:
+ log.debug(
+ "Error when using hybrid search, using"
+                        " non-hybrid search as fallback."
+ )
+
+ if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
query=query,
@@ -283,7 +362,6 @@ def get_rag_context(
)
except Exception as e:
log.exception(e)
- context = None
if context:
relevant_contexts.append({**context, "source": file})
@@ -391,51 +469,11 @@ def generate_openai_batch_embeddings(
return None
-from typing import Any
-
-from langchain_core.callbacks import CallbackManagerForRetrieverRun
-from langchain_core.retrievers import BaseRetriever
-
-
-class ChromaRetriever(BaseRetriever):
- collection: Any
- embedding_function: Any
- top_n: int
-
- def _get_relevant_documents(
- self,
- query: str,
- *,
- run_manager: CallbackManagerForRetrieverRun,
- ) -> list[Document]:
- query_embeddings = self.embedding_function(query)
-
- results = self.collection.query(
- query_embeddings=[query_embeddings],
- n_results=self.top_n,
- )
-
- ids = results["ids"][0]
- metadatas = results["metadatas"][0]
- documents = results["documents"][0]
-
- results = []
- for idx in range(len(ids)):
- results.append(
- Document(
- metadata=metadatas[idx],
- page_content=documents[idx],
- )
- )
- return results
-
-
import operator
from typing import Optional, Sequence
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
-from langchain_core.pydantic_v1 import Extra
class RerankCompressor(BaseDocumentCompressor):
@@ -445,7 +483,7 @@ class RerankCompressor(BaseDocumentCompressor):
r_score: float
class Config:
- extra = Extra.forbid
+ extra = "forbid"
arbitrary_types_allowed = True
def compress_documents(
diff --git a/backend/open_webui/apps/rag/vector/connector.py b/backend/open_webui/apps/rag/vector/connector.py
new file mode 100644
index 000000000..073becdbe
--- /dev/null
+++ b/backend/open_webui/apps/rag/vector/connector.py
@@ -0,0 +1,10 @@
+from open_webui.apps.rag.vector.dbs.chroma import ChromaClient
+from open_webui.apps.rag.vector.dbs.milvus import MilvusClient
+
+
+from open_webui.config import VECTOR_DB
+
+if VECTOR_DB == "milvus":
+ VECTOR_DB_CLIENT = MilvusClient()
+else:
+ VECTOR_DB_CLIENT = ChromaClient()
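Every call site in this diff (RAG ingestion, queries, memories) now goes through this one interface, so switching backends is purely a `VECTOR_DB` config change. A usage sketch against the methods defined below; the two-dimensional vectors are toy placeholders for real embedding output:

```python
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT

VECTOR_DB_CLIENT.insert(
    collection_name="demo",
    items=[
        {"id": "1", "text": "hello", "vector": [0.1, 0.2], "metadata": {"k": "v"}},
    ],
)
result = VECTOR_DB_CLIENT.search(collection_name="demo", vectors=[[0.1, 0.2]], limit=1)
print(result.ids, result.distances)
```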
diff --git a/backend/open_webui/apps/rag/vector/dbs/chroma.py b/backend/open_webui/apps/rag/vector/dbs/chroma.py
new file mode 100644
index 000000000..5f9420108
--- /dev/null
+++ b/backend/open_webui/apps/rag/vector/dbs/chroma.py
@@ -0,0 +1,122 @@
+import chromadb
+from chromadb import Settings
+from chromadb.utils.batch_utils import create_batches
+
+from typing import Optional
+
+from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.config import (
+ CHROMA_DATA_PATH,
+ CHROMA_HTTP_HOST,
+ CHROMA_HTTP_PORT,
+ CHROMA_HTTP_HEADERS,
+ CHROMA_HTTP_SSL,
+ CHROMA_TENANT,
+ CHROMA_DATABASE,
+)
+
+
+class ChromaClient:
+ def __init__(self):
+ if CHROMA_HTTP_HOST != "":
+ self.client = chromadb.HttpClient(
+ host=CHROMA_HTTP_HOST,
+ port=CHROMA_HTTP_PORT,
+ headers=CHROMA_HTTP_HEADERS,
+ ssl=CHROMA_HTTP_SSL,
+ tenant=CHROMA_TENANT,
+ database=CHROMA_DATABASE,
+ settings=Settings(allow_reset=True, anonymized_telemetry=False),
+ )
+ else:
+ self.client = chromadb.PersistentClient(
+ path=CHROMA_DATA_PATH,
+ settings=Settings(allow_reset=True, anonymized_telemetry=False),
+ tenant=CHROMA_TENANT,
+ database=CHROMA_DATABASE,
+ )
+
+ def has_collection(self, collection_name: str) -> bool:
+ # Check if the collection exists based on the collection name.
+ collections = self.client.list_collections()
+ return collection_name in [collection.name for collection in collections]
+
+ def delete_collection(self, collection_name: str):
+ # Delete the collection based on the collection name.
+ return self.client.delete_collection(name=collection_name)
+
+ def search(
+ self, collection_name: str, vectors: list[list[float | int]], limit: int
+ ) -> Optional[SearchResult]:
+ # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
+ collection = self.client.get_collection(name=collection_name)
+ if collection:
+ result = collection.query(
+ query_embeddings=vectors,
+ n_results=limit,
+ )
+
+ return SearchResult(
+ **{
+ "ids": result["ids"],
+ "distances": result["distances"],
+ "documents": result["documents"],
+ "metadatas": result["metadatas"],
+ }
+ )
+ return None
+
+ def get(self, collection_name: str) -> Optional[GetResult]:
+ # Get all the items in the collection.
+ collection = self.client.get_collection(name=collection_name)
+ if collection:
+ result = collection.get()
+ return GetResult(
+ **{
+ "ids": [result["ids"]],
+ "documents": [result["documents"]],
+ "metadatas": [result["metadatas"]],
+ }
+ )
+ return None
+
+ def insert(self, collection_name: str, items: list[VectorItem]):
+ # Insert the items into the collection, if the collection does not exist, it will be created.
+ collection = self.client.get_or_create_collection(name=collection_name)
+
+ ids = [item["id"] for item in items]
+ documents = [item["text"] for item in items]
+ embeddings = [item["vector"] for item in items]
+ metadatas = [item["metadata"] for item in items]
+
+ for batch in create_batches(
+ api=self.client,
+ documents=documents,
+ embeddings=embeddings,
+ ids=ids,
+ metadatas=metadatas,
+ ):
+ collection.add(*batch)
+
+ def upsert(self, collection_name: str, items: list[VectorItem]):
+ # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
+ collection = self.client.get_or_create_collection(name=collection_name)
+
+ ids = [item["id"] for item in items]
+ documents = [item["text"] for item in items]
+ embeddings = [item["vector"] for item in items]
+ metadatas = [item["metadata"] for item in items]
+
+ collection.upsert(
+ ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
+ )
+
+ def delete(self, collection_name: str, ids: list[str]):
+ # Delete the items from the collection based on the ids.
+ collection = self.client.get_collection(name=collection_name)
+ if collection:
+ collection.delete(ids=ids)
+
+ def reset(self):
+ # Resets the database. This will delete all collections and item entries.
+ return self.client.reset()
diff --git a/backend/open_webui/apps/rag/vector/dbs/milvus.py b/backend/open_webui/apps/rag/vector/dbs/milvus.py
new file mode 100644
index 000000000..f205b9521
--- /dev/null
+++ b/backend/open_webui/apps/rag/vector/dbs/milvus.py
@@ -0,0 +1,205 @@
+from pymilvus import MilvusClient as Client
+from pymilvus import FieldSchema, DataType
+import json
+
+from typing import Optional
+
+from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.config import (
+ MILVUS_URI,
+)
+
+
+class MilvusClient:
+ def __init__(self):
+ self.collection_prefix = "open_webui"
+ self.client = Client(uri=MILVUS_URI)
+
+ def _result_to_get_result(self, result) -> GetResult:
+ print(result)
+
+ ids = []
+ documents = []
+ metadatas = []
+
+ for match in result:
+ _ids = []
+ _documents = []
+ _metadatas = []
+
+ for item in match:
+ _ids.append(item.get("id"))
+ _documents.append(item.get("data", {}).get("text"))
+ _metadatas.append(item.get("metadata"))
+
+ ids.append(_ids)
+ documents.append(_documents)
+ metadatas.append(_metadatas)
+
+ return GetResult(
+ **{
+ "ids": ids,
+ "documents": documents,
+ "metadatas": metadatas,
+ }
+ )
+
+ def _result_to_search_result(self, result) -> SearchResult:
+ print(result)
+
+ ids = []
+ distances = []
+ documents = []
+ metadatas = []
+
+ for match in result:
+ _ids = []
+ _distances = []
+ _documents = []
+ _metadatas = []
+
+ for item in match:
+ _ids.append(item.get("id"))
+ _distances.append(item.get("distance"))
+ _documents.append(item.get("entity", {}).get("data", {}).get("text"))
+ _metadatas.append(item.get("entity", {}).get("metadata"))
+
+ ids.append(_ids)
+ distances.append(_distances)
+ documents.append(_documents)
+ metadatas.append(_metadatas)
+
+ return SearchResult(
+ **{
+ "ids": ids,
+ "distances": distances,
+ "documents": documents,
+ "metadatas": metadatas,
+ }
+ )
+
+ def _create_collection(self, collection_name: str, dimension: int):
+ schema = self.client.create_schema(
+ auto_id=False,
+ enable_dynamic_field=True,
+ )
+ schema.add_field(
+ field_name="id",
+ datatype=DataType.VARCHAR,
+ is_primary=True,
+ max_length=65535,
+ )
+ schema.add_field(
+ field_name="vector",
+ datatype=DataType.FLOAT_VECTOR,
+ dim=dimension,
+ description="vector",
+ )
+ schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
+ schema.add_field(
+ field_name="metadata", datatype=DataType.JSON, description="metadata"
+ )
+
+ index_params = self.client.prepare_index_params()
+ index_params.add_index(
+ field_name="vector", index_type="HNSW", metric_type="COSINE", params={}
+ )
+
+ self.client.create_collection(
+ collection_name=f"{self.collection_prefix}_{collection_name}",
+ schema=schema,
+ index_params=index_params,
+ )
+
+ def has_collection(self, collection_name: str) -> bool:
+ # Check if the collection exists based on the collection name.
+ return self.client.has_collection(
+ collection_name=f"{self.collection_prefix}_{collection_name}"
+ )
+
+ def delete_collection(self, collection_name: str):
+ # Delete the collection based on the collection name.
+ return self.client.drop_collection(
+ collection_name=f"{self.collection_prefix}_{collection_name}"
+ )
+
+ def search(
+ self, collection_name: str, vectors: list[list[float | int]], limit: int
+ ) -> Optional[SearchResult]:
+ # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
+ result = self.client.search(
+ collection_name=f"{self.collection_prefix}_{collection_name}",
+ data=vectors,
+ limit=limit,
+ output_fields=["data", "metadata"],
+ )
+
+ return self._result_to_search_result(result)
+
+ def get(self, collection_name: str) -> Optional[GetResult]:
+ # Get all the items in the collection.
+ result = self.client.query(
+ collection_name=f"{self.collection_prefix}_{collection_name}",
+ filter='id != ""',
+ )
+ return self._result_to_get_result([result])
+
+ def insert(self, collection_name: str, items: list[VectorItem]):
+ # Insert the items into the collection, if the collection does not exist, it will be created.
+ if not self.client.has_collection(
+ collection_name=f"{self.collection_prefix}_{collection_name}"
+ ):
+ self._create_collection(
+ collection_name=collection_name, dimension=len(items[0]["vector"])
+ )
+
+ return self.client.insert(
+ collection_name=f"{self.collection_prefix}_{collection_name}",
+ data=[
+ {
+ "id": item["id"],
+ "vector": item["vector"],
+ "data": {"text": item["text"]},
+ "metadata": item["metadata"],
+ }
+ for item in items
+ ],
+ )
+
+ def upsert(self, collection_name: str, items: list[VectorItem]):
+ # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
+ if not self.client.has_collection(
+ collection_name=f"{self.collection_prefix}_{collection_name}"
+ ):
+ self._create_collection(
+ collection_name=collection_name, dimension=len(items[0]["vector"])
+ )
+
+ return self.client.upsert(
+ collection_name=f"{self.collection_prefix}_{collection_name}",
+ data=[
+ {
+ "id": item["id"],
+ "vector": item["vector"],
+ "data": {"text": item["text"]},
+ "metadata": item["metadata"],
+ }
+ for item in items
+ ],
+ )
+
+ def delete(self, collection_name: str, ids: list[str]):
+ # Delete the items from the collection based on the ids.
+
+ return self.client.delete(
+ collection_name=f"{self.collection_prefix}_{collection_name}",
+ ids=ids,
+ )
+
+ def reset(self):
+ # Resets the database. This will delete all collections and item entries.
+
+ collection_names = self.client.list_collections()
+ for collection_name in collection_names:
+ if collection_name.startswith(self.collection_prefix):
+ self.client.drop_collection(collection_name=collection_name)
diff --git a/backend/open_webui/apps/rag/vector/main.py b/backend/open_webui/apps/rag/vector/main.py
new file mode 100644
index 000000000..f0cf0c038
--- /dev/null
+++ b/backend/open_webui/apps/rag/vector/main.py
@@ -0,0 +1,19 @@
+from pydantic import BaseModel
+from typing import Optional, List, Any
+
+
+class VectorItem(BaseModel):
+ id: str
+ text: str
+ vector: List[float | int]
+ metadata: Any
+
+
+class GetResult(BaseModel):
+ ids: Optional[List[List[str]]]
+ documents: Optional[List[List[str]]]
+ metadatas: Optional[List[List[Any]]]
+
+
+class SearchResult(GetResult):
+ distances: Optional[List[List[float | int]]]
diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/apps/socket/main.py
index 5985bc524..e41ef8412 100644
--- a/backend/open_webui/apps/socket/main.py
+++ b/backend/open_webui/apps/socket/main.py
@@ -2,9 +2,16 @@ import asyncio
import socketio
from open_webui.apps.webui.models.users import Users
+from open_webui.env import ENABLE_WEBSOCKET_SUPPORT
from open_webui.utils.utils import decode_token
-sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
+sio = socketio.AsyncServer(
+ cors_allowed_origins=[],
+ async_mode="asgi",
+ transports=(["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
+ allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
+ always_connect=True,
+)
app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
# Dictionary to maintain the user pool
@@ -32,7 +39,7 @@ async def connect(sid, environ, auth):
else:
USER_POOL[user.id] = [sid]
- print(f"user {user.name}({user.id}) connected with session ID {sid}")
+ # print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
await sio.emit("usage", {"models": get_models_in_use()})
@@ -40,7 +47,7 @@ async def connect(sid, environ, auth):
@sio.on("user-join")
async def user_join(sid, data):
- print("user-join", sid, data)
+ # print("user-join", sid, data)
auth = data["auth"] if "auth" in data else None
if not auth or "token" not in auth:
@@ -60,7 +67,7 @@ async def user_join(sid, data):
else:
USER_POOL[user.id] = [sid]
- print(f"user {user.name}({user.id}) connected with session ID {sid}")
+ # print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
@@ -109,7 +116,7 @@ async def remove_after_timeout(sid, model_id):
try:
await asyncio.sleep(TIMEOUT_DURATION)
if model_id in USAGE_POOL:
- print(USAGE_POOL[model_id]["sids"])
+ # print(USAGE_POOL[model_id]["sids"])
USAGE_POOL[model_id]["sids"].remove(sid)
USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
@@ -136,7 +143,8 @@ async def disconnect(sid):
await sio.emit("user-count", {"count": len(USER_POOL)})
else:
- print(f"Unknown session ID {sid} disconnected")
+ pass
+ # print(f"Unknown session ID {sid} disconnected")
def get_event_emitter(request_info):
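With `ENABLE_WEBSOCKET_SUPPORT` false, the server above advertises only long-polling and refuses upgrades, which is the point of the flag on proxies that mishandle websockets. A client-side sketch using the `python-socketio` client; the host is a placeholder:

```python
import socketio

# Restrict the client to the transport the server will actually offer.
sio = socketio.Client()
sio.connect(
    "http://localhost:8080",        # placeholder host
    socketio_path="/ws/socket.io",  # path from the ASGIApp above
    transports=["polling"],
)
```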
diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/apps/webui/routers/auths.py
index 2366841e1..bfa460836 100644
--- a/backend/open_webui/apps/webui/routers/auths.py
+++ b/backend/open_webui/apps/webui/routers/auths.py
@@ -190,8 +190,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
async def signup(request: Request, response: Response, form_data: SignupForm):
if (
not request.app.state.config.ENABLE_SIGNUP
- and request.app.state.config.ENABLE_LOGIN_FORM
- and WEBUI_AUTH
+ or not request.app.state.config.ENABLE_LOGIN_FORM
+ or not WEBUI_AUTH
):
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
diff --git a/backend/open_webui/apps/webui/routers/memories.py b/backend/open_webui/apps/webui/routers/memories.py
index 914b69e7e..d659833bc 100644
--- a/backend/open_webui/apps/webui/routers/memories.py
+++ b/backend/open_webui/apps/webui/routers/memories.py
@@ -1,12 +1,13 @@
+from fastapi import APIRouter, Depends, HTTPException, Request
+from pydantic import BaseModel
import logging
from typing import Optional
from open_webui.apps.webui.models.memories import Memories, MemoryModel
-from open_webui.config import CHROMA_CLIENT
-from open_webui.env import SRC_LOG_LEVELS
-from fastapi import APIRouter, Depends, HTTPException, Request
-from pydantic import BaseModel
+from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.utils import get_verified_user
+from open_webui.env import SRC_LOG_LEVELS
+
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -49,14 +50,17 @@ async def add_memory(
user=Depends(get_verified_user),
):
memory = Memories.insert_new_memory(user.id, form_data.content)
- memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
- collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
- collection.upsert(
- documents=[memory.content],
- ids=[memory.id],
- embeddings=[memory_embedding],
- metadatas=[{"created_at": memory.created_at}],
+ VECTOR_DB_CLIENT.upsert(
+ collection_name=f"user-memory-{user.id}",
+ items=[
+ {
+ "id": memory.id,
+ "text": memory.content,
+ "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+ "metadata": {"created_at": memory.created_at},
+ }
+ ],
)
return memory
@@ -76,12 +80,10 @@ class QueryMemoryForm(BaseModel):
async def query_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
):
- query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
- collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
-
- results = collection.query(
- query_embeddings=[query_embedding],
- n_results=form_data.k, # how many results to return
+ results = VECTOR_DB_CLIENT.search(
+ collection_name=f"user-memory-{user.id}",
+ vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
+ limit=form_data.k,
)
return results
@@ -94,17 +96,25 @@ async def query_memory(
async def reset_memory_from_vector_db(
request: Request, user=Depends(get_verified_user)
):
- CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
- collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
+ VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
memories = Memories.get_memories_by_user_id(user.id)
- for memory in memories:
- memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
- collection.upsert(
- documents=[memory.content],
- ids=[memory.id],
- embeddings=[memory_embedding],
- )
+ VECTOR_DB_CLIENT.upsert(
+ collection_name=f"user-memory-{user.id}",
+ items=[
+ {
+ "id": memory.id,
+ "text": memory.content,
+ "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+ "metadata": {
+ "created_at": memory.created_at,
+ "updated_at": memory.updated_at,
+ },
+ }
+ for memory in memories
+ ],
+ )
+
return True
@@ -119,7 +129,7 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)):
if result:
try:
- CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
+ VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
except Exception as e:
log.error(e)
return True
@@ -144,16 +154,18 @@ async def update_memory_by_id(
raise HTTPException(status_code=404, detail="Memory not found")
if form_data.content is not None:
- memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
- collection = CHROMA_CLIENT.get_or_create_collection(
- name=f"user-memory-{user.id}"
- )
- collection.upsert(
- documents=[form_data.content],
- ids=[memory.id],
- embeddings=[memory_embedding],
- metadatas=[
- {"created_at": memory.created_at, "updated_at": memory.updated_at}
+ VECTOR_DB_CLIENT.upsert(
+ collection_name=f"user-memory-{user.id}",
+ items=[
+ {
+ "id": memory.id,
+ "text": memory.content,
+ "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+ "metadata": {
+ "created_at": memory.created_at,
+ "updated_at": memory.updated_at,
+ },
+ }
],
)
@@ -170,10 +182,9 @@ async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
if result:
- collection = CHROMA_CLIENT.get_or_create_collection(
- name=f"user-memory-{user.id}"
+ VECTOR_DB_CLIENT.delete(
+ collection_name=f"user-memory-{user.id}", ids=[memory_id]
)
- collection.delete(ids=[memory_id])
return True
return False
diff --git a/backend/open_webui/apps/webui/routers/models.py b/backend/open_webui/apps/webui/routers/models.py
index a99c65d76..a5cb2395e 100644
--- a/backend/open_webui/apps/webui/routers/models.py
+++ b/backend/open_webui/apps/webui/routers/models.py
@@ -18,8 +18,18 @@ router = APIRouter()
@router.get("/", response_model=list[ModelResponse])
-async def get_models(user=Depends(get_verified_user)):
- return Models.get_all_models()
+async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
+ if id:
+ model = Models.get_model_by_id(id)
+ if model:
+ return [model]
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+ else:
+ return Models.get_all_models()
############################
@@ -50,24 +60,6 @@ async def add_new_model(
)
-############################
-# GetModelById
-############################
-
-
-@router.get("/", response_model=Optional[ModelModel])
-async def get_model_by_id(id: str, user=Depends(get_verified_user)):
- model = Models.get_model_by_id(id)
-
- if model:
- return model
- else:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail=ERROR_MESSAGES.NOT_FOUND,
- )
-
-
############################
# UpdateModelById
############################
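
A hypothetical client-side view of the merged endpoint above: passing `?id=` returns the single model wrapped in a one-element list, while omitting it returns the full list. The base URL and token below are placeholders, not part of the diff.

```python
# Hypothetical usage of the merged GET /models endpoint (URL and token are
# placeholders for an actual deployment).
import requests

base_url = "http://localhost:8080/api/v1/models/"  # assumed mount path
headers = {"Authorization": "Bearer <token>"}      # placeholder token

all_models = requests.get(base_url, headers=headers).json()
one_model = requests.get(base_url, params={"id": "my-model"}, headers=headers).json()
```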
diff --git a/backend/open_webui/apps/webui/utils.py b/backend/open_webui/apps/webui/utils.py
index 2d537af51..8bf48e400 100644
--- a/backend/open_webui/apps/webui/utils.py
+++ b/backend/open_webui/apps/webui/utils.py
@@ -4,7 +4,7 @@ import subprocess
import sys
from importlib import util
import types
-
+import tempfile
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.tools import Tools
@@ -84,7 +84,15 @@ def load_toolkit_module_by_id(toolkit_id, content=None):
module = types.ModuleType(module_name)
sys.modules[module_name] = module
+ # Create a temporary file and use it to define `__file__` so
+ # that it works as expected from the module's perspective.
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+
try:
+ with open(temp_file.name, "w", encoding="utf-8") as f:
+ f.write(content)
+ module.__dict__["__file__"] = temp_file.name
+
# Executing the modified content in the created module's namespace
exec(content, module.__dict__)
frontmatter = extract_frontmatter(content)
@@ -96,9 +104,11 @@ def load_toolkit_module_by_id(toolkit_id, content=None):
else:
raise Exception("No Tools class found in the module")
except Exception as e:
- print(f"Error loading module: {toolkit_id}")
+ print(f"Error loading module: {toolkit_id}: {e}")
del sys.modules[module_name] # Clean up
raise e
+ finally:
+ os.unlink(temp_file.name)
def load_function_module_by_id(function_id, content=None):
@@ -118,7 +128,14 @@ def load_function_module_by_id(function_id, content=None):
module = types.ModuleType(module_name)
sys.modules[module_name] = module
+ # Create a temporary file and use it to define `__file__` so
+ # that it works as expected from the module's perspective.
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
try:
+ with open(temp_file.name, "w", encoding="utf-8") as f:
+ f.write(content)
+ module.__dict__["__file__"] = temp_file.name
+
# Execute the modified content in the created module's namespace
exec(content, module.__dict__)
frontmatter = extract_frontmatter(content)
@@ -134,11 +151,13 @@ def load_function_module_by_id(function_id, content=None):
else:
raise Exception("No Function class found in the module")
except Exception as e:
- print(f"Error loading module: {function_id}")
+ print(f"Error loading module: {function_id}: {e}")
del sys.modules[module_name] # Cleanup by removing the module in case of error
Functions.update_function_by_id(function_id, {"is_active": False})
raise e
+ finally:
+ os.unlink(temp_file.name)
def install_frontmatter_requirements(requirements):
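
The temp-file dance above exists because code run through `exec()` otherwise has no `__file__`, which breaks tools and functions that resolve paths relative to their own module. A minimal standalone reproduction of the trick:

```python
# Minimal demo of the pattern above: write the dynamic source to a temp file,
# point __file__ at it, exec the source, then clean up the file.
import os
import tempfile
import types

content = "import os\nprint('module dir:', os.path.dirname(__file__))"

module = types.ModuleType("demo_module")
temp_file = tempfile.NamedTemporaryFile(delete=False)
try:
    with open(temp_file.name, "w", encoding="utf-8") as f:
        f.write(content)
    module.__dict__["__file__"] = temp_file.name
    exec(content, module.__dict__)
finally:
    os.unlink(temp_file.name)
```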
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 5ccb40d47..86d8a47a3 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -11,7 +11,6 @@ import chromadb
import requests
import yaml
from open_webui.apps.webui.internal.db import Base, get_db
-from chromadb import Settings
from open_webui.env import (
OPEN_WEBUI_DIR,
DATA_DIR,
@@ -540,40 +539,6 @@ Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
-
-####################################
-# LITELLM_CONFIG
-####################################
-
-
-def create_config_file(file_path):
- directory = os.path.dirname(file_path)
-
- # Check if directory exists, if not, create it
- if not os.path.exists(directory):
- os.makedirs(directory)
-
- # Data to write into the YAML file
- config_data = {
- "general_settings": {},
- "litellm_settings": {},
- "model_list": [],
- "router_settings": {},
- }
-
- # Write data to YAML file
- with open(file_path, "w") as file:
- yaml.dump(config_data, file)
-
-
-LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml"
-
-# if not os.path.exists(LITELLM_CONFIG_PATH):
-# log.info("Config file doesn't exist. Creating...")
-# create_config_file(LITELLM_CONFIG_PATH)
-# log.info("Config file created successfully.")
-
-
####################################
# OLLAMA_BASE_URL
####################################
@@ -923,25 +888,12 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
####################################
-# RAG document content extraction
+# Vector Database
####################################
-CONTENT_EXTRACTION_ENGINE = PersistentConfig(
- "CONTENT_EXTRACTION_ENGINE",
- "rag.CONTENT_EXTRACTION_ENGINE",
- os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
-)
-
-TIKA_SERVER_URL = PersistentConfig(
- "TIKA_SERVER_URL",
- "rag.tika_server_url",
- os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
-)
-
-####################################
-# RAG
-####################################
+VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")
+# Chroma
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
@@ -958,8 +910,29 @@ else:
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# This uses the model defined in the Dockerfile ENV variable. If you don't use Docker or Docker-based deployments such as k8s, the default embedding model (sentence-transformers/all-MiniLM-L6-v2) will be used.
+# Milvus
+
+MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
+
+####################################
+# RAG
+####################################
+
+# RAG Content Extraction
+CONTENT_EXTRACTION_ENGINE = PersistentConfig(
+ "CONTENT_EXTRACTION_ENGINE",
+ "rag.CONTENT_EXTRACTION_ENGINE",
+ os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
+)
+
+TIKA_SERVER_URL = PersistentConfig(
+ "TIKA_SERVER_URL",
+ "rag.tika_server_url",
+ os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
+)
+
RAG_TOP_K = PersistentConfig(
- "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5"))
+ "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3"))
)
RAG_RELEVANCE_THRESHOLD = PersistentConfig(
"RAG_RELEVANCE_THRESHOLD",
@@ -1048,36 +1021,8 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
-
-if CHROMA_HTTP_HOST != "":
- CHROMA_CLIENT = chromadb.HttpClient(
- host=CHROMA_HTTP_HOST,
- port=CHROMA_HTTP_PORT,
- headers=CHROMA_HTTP_HEADERS,
- ssl=CHROMA_HTTP_SSL,
- tenant=CHROMA_TENANT,
- database=CHROMA_DATABASE,
- settings=Settings(allow_reset=True, anonymized_telemetry=False),
- )
-else:
- CHROMA_CLIENT = chromadb.PersistentClient(
- path=CHROMA_DATA_PATH,
- settings=Settings(allow_reset=True, anonymized_telemetry=False),
- tenant=CHROMA_TENANT,
- database=CHROMA_DATABASE,
- )
-
-
-# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
-USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
-
-if USE_CUDA.lower() == "true":
- DEVICE_TYPE = "cuda"
-else:
- DEVICE_TYPE = "cpu"
-
CHUNK_SIZE = PersistentConfig(
- "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500"))
+ "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1000"))
)
CHUNK_OVERLAP = PersistentConfig(
"CHUNK_OVERLAP",
@@ -1085,19 +1030,25 @@ CHUNK_OVERLAP = PersistentConfig(
int(os.environ.get("CHUNK_OVERLAP", "100")),
)
-DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside XML tags.
+DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
+
- [context]
+[context]
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
+
+- If you don't know, just say so.
+- If you are not sure, ask for clarification.
+- Answer in the same language as the user query.
+- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
+- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
+- Answer directly and without using xml tags.
+
-Given the context information, answer the query.
-Query: [query]"""
+
+[query]
+
+"""
RAG_TEMPLATE = PersistentConfig(
"RAG_TEMPLATE",
@@ -1267,6 +1218,37 @@ AUTOMATIC1111_API_AUTH = PersistentConfig(
os.getenv("AUTOMATIC1111_API_AUTH", ""),
)
+AUTOMATIC1111_CFG_SCALE = PersistentConfig(
+ "AUTOMATIC1111_CFG_SCALE",
+ "image_generation.automatic1111.cfg_scale",
+ (
+ float(os.environ.get("AUTOMATIC1111_CFG_SCALE"))
+ if os.environ.get("AUTOMATIC1111_CFG_SCALE")
+ else None
+ ),
+)
+
+
+AUTOMATIC1111_SAMPLER = PersistentConfig(
+ "AUTOMATIC1111_SAMPLERE",
+ "image_generation.automatic1111.sampler",
+ (
+ os.environ.get("AUTOMATIC1111_SAMPLER")
+ if os.environ.get("AUTOMATIC1111_SAMPLER")
+ else None
+ ),
+)
+
+AUTOMATIC1111_SCHEDULER = PersistentConfig(
+ "AUTOMATIC1111_SCHEDULER",
+ "image_generation.automatic1111.scheduler",
+ (
+ os.environ.get("AUTOMATIC1111_SCHEDULER")
+ if os.environ.get("AUTOMATIC1111_SCHEDULER")
+ else None
+ ),
+)
+
COMFYUI_BASE_URL = PersistentConfig(
"COMFYUI_BASE_URL",
"image_generation.comfyui.base_url",
@@ -1490,3 +1472,17 @@ AUDIO_TTS_SPLIT_ON = PersistentConfig(
"audio.tts.split_on",
os.getenv("AUDIO_TTS_SPLIT_ON", "punctuation"),
)
+
+AUDIO_TTS_AZURE_SPEECH_REGION = PersistentConfig(
+ "AUDIO_TTS_AZURE_SPEECH_REGION",
+ "audio.tts.azure.speech_region",
+ os.getenv("AUDIO_TTS_AZURE_SPEECH_REGION", "eastus"),
+)
+
+AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT = PersistentConfig(
+ "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT",
+ "audio.tts.azure.speech_output_format",
+ os.getenv(
+ "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3"
+ ),
+)
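
The new `AUTOMATIC1111_*` entries earlier in this file all follow the same "cast only when set" pattern, so an unset variable yields `None` and the admin-UI default applies instead of a hard-coded value. Condensed into a helper (the helper name is illustrative, not part of the codebase):

```python
# Illustrative restatement of the AUTOMATIC1111_* env pattern above.
import os


def optional_env(name: str, cast=str):
    raw = os.environ.get(name)
    return cast(raw) if raw else None


cfg_scale = optional_env("AUTOMATIC1111_CFG_SCALE", float)  # e.g. 7.5, or None
sampler = optional_env("AUTOMATIC1111_SAMPLER")             # e.g. "Euler a", or None
scheduler = optional_env("AUTOMATIC1111_SCHEDULER")         # e.g. "Karras", or None
```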
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index b716769c2..89422e57b 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -31,6 +31,28 @@ try:
except ImportError:
print("dotenv not installed, skipping...")
+DOCKER = os.environ.get("DOCKER", "False").lower() == "true"
+
+# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
+USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
+
+if USE_CUDA.lower() == "true":
+ try:
+ import torch
+
+ assert torch.cuda.is_available(), "CUDA not available"
+ DEVICE_TYPE = "cuda"
+ except Exception as e:
+ cuda_error = (
+ "Error when testing CUDA but USE_CUDA_DOCKER is true. "
+ f"Resetting USE_CUDA_DOCKER to false: {e}"
+ )
+ os.environ["USE_CUDA_DOCKER"] = "false"
+ USE_CUDA = "false"
+ DEVICE_TYPE = "cpu"
+else:
+ DEVICE_TYPE = "cpu"
+
####################################
# LOGGING
@@ -47,6 +69,9 @@ else:
log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
+if "cuda_error" in locals():
+    log.error(cuda_error)
+
log_sources = [
"AUDIO",
"COMFYUI",
@@ -273,3 +298,7 @@ WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
+
+ENABLE_WEBSOCKET_SUPPORT = (
+ os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
+)
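
The CUDA block above no longer trusts `USE_CUDA_DOCKER` blindly: it verifies that torch can actually see a device and quietly falls back to CPU otherwise. The decision logic in isolation:

```python
# Standalone sketch of the device-selection fallback above.
import os


def pick_device() -> str:
    if os.environ.get("USE_CUDA_DOCKER", "false").lower() == "true":
        try:
            import torch

            if torch.cuda.is_available():
                return "cuda"
        except Exception:
            pass  # torch missing or CUDA unusable; fall back to CPU
    return "cpu"
```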
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 8914cb491..319d95165 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -109,6 +109,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse, Response, StreamingResponse
+from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.misc import (
add_or_update_system_message,
@@ -586,8 +587,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if len(contexts) > 0:
context_string = "/n".join(contexts).strip()
prompt = get_last_user_message(body["messages"])
+
if prompt is None:
raise Exception("No user message found")
+ if (
+ rag_app.state.config.RELEVANCE_THRESHOLD == 0
+ and context_string.strip() == ""
+ ):
+                log.debug(
+                    "With a RAG relevance threshold of 0, the retrieved context should not be empty"
+                )
+
# Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message
if model["owned_by"] == "ollama":
@@ -780,6 +790,8 @@ app.add_middleware(
allow_headers=["*"],
)
+app.add_middleware(SecurityHeadersMiddleware)
+
@app.middleware("http")
async def commit_session_after_request(request: Request, call_next):
@@ -812,6 +824,24 @@ async def update_embedding_function(request: Request, call_next):
return response
+@app.middleware("http")
+async def inspect_websocket(request: Request, call_next):
+ if (
+ "/ws/socket.io" in request.url.path
+ and request.query_params.get("transport") == "websocket"
+ ):
+ upgrade = (request.headers.get("Upgrade") or "").lower()
+ connection = (request.headers.get("Connection") or "").lower().split(",")
+            # Check that the correct headers are present for an upgrade; otherwise reject the connection
+ # This is to work around this upstream issue: https://github.com/miguelgrinberg/python-engineio/issues/367
+ if upgrade != "websocket" or "upgrade" not in connection:
+ return JSONResponse(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ content={"detail": "Invalid WebSocket upgrade request"},
+ )
+ return await call_next(request)
+
+
app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app)
@@ -1368,9 +1398,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
# Check if the user has a custom task model
# If the user has a custom task model, use that model
- model_id = get_task_model_id(model_id)
+ task_model_id = get_task_model_id(model_id)
- print(model_id)
+ print(task_model_id)
if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
@@ -1397,10 +1427,16 @@ Prompt: {{prompt:middletruncate:8000}}"""
)
payload = {
- "model": model_id,
+ "model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
- "max_tokens": 50,
+ **(
+ {"max_tokens": 50}
+ if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
+ else {
+ "max_completion_tokens": 50,
+ }
+ ),
"chat_id": form_data.get("chat_id", None),
"metadata": {"task": str(TASKS.TITLE_GENERATION)},
}
@@ -1445,9 +1481,8 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
- model_id = get_task_model_id(model_id)
-
- print(model_id)
+ task_model_id = get_task_model_id(model_id)
+ print(task_model_id)
if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
@@ -1469,10 +1504,16 @@ Search Query:"""
print("content", content)
payload = {
- "model": model_id,
+ "model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
- "max_tokens": 30,
+ **(
+ {"max_tokens": 30}
+ if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
+ else {
+ "max_completion_tokens": 30,
+ }
+ ),
"metadata": {"task": str(TASKS.QUERY_GENERATION)},
}
@@ -1511,9 +1552,8 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
# Check if the user has a custom task model
# If the user has a custom task model, use that model
- model_id = get_task_model_id(model_id)
-
- print(model_id)
+ task_model_id = get_task_model_id(model_id)
+ print(task_model_id)
template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
@@ -1531,10 +1571,16 @@ Message: """{{prompt}}"""
)
payload = {
- "model": model_id,
+ "model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
- "max_tokens": 4,
+ **(
+ {"max_tokens": 4}
+ if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
+ else {
+ "max_completion_tokens": 4,
+ }
+ ),
"chat_id": form_data.get("chat_id", None),
"metadata": {"task": str(TASKS.EMOJI_GENERATION)},
}
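
The three task endpoints above (title, search query, and emoji generation) all gained the same payload switch: models owned by Ollama keep `max_tokens`, while OpenAI-style models receive `max_completion_tokens`, the parameter the non-streaming o1 models require. Condensed into one helper (the helper name is illustrative):

```python
# Illustrative helper equivalent to the inline payload logic above.
def token_limit_param(owned_by: str, limit: int) -> dict:
    # o1-style OpenAI models reject "max_tokens" in favor of
    # "max_completion_tokens"; Ollama still expects "max_tokens".
    if owned_by == "ollama":
        return {"max_tokens": limit}
    return {"max_completion_tokens": limit}


payload = {
    "model": "o1-preview",  # hypothetical task model
    "messages": [{"role": "user", "content": "Generate a concise title."}],
    "stream": False,
    **token_limit_param("openai", 50),
}
```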
diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py
index 227cca45f..b2654cd25 100644
--- a/backend/open_webui/utils/payload.py
+++ b/backend/open_webui/utils/payload.py
@@ -44,9 +44,9 @@ def apply_model_params_to_body(
def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
mappings = {
"temperature": float,
- "top_p": int,
+ "top_p": float,
"max_tokens": int,
- "frequency_penalty": int,
+ "frequency_penalty": float,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
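
The cast fix above matters because `top_p` and `frequency_penalty` are fractional parameters: when the stored value was a float, `int()` silently truncated it instead of preserving it.

```python
# Why int() was wrong for these parameters: truncation, not rounding.
assert int(0.9) == 0       # old behavior: top_p=0.9 collapsed to 0
assert float(0.9) == 0.9   # fixed behavior preserves the value
assert int(1.5) == 1       # frequency_penalty=1.5 was silently clamped
```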
diff --git a/backend/open_webui/utils/security_headers.py b/backend/open_webui/utils/security_headers.py
new file mode 100644
index 000000000..69a464814
--- /dev/null
+++ b/backend/open_webui/utils/security_headers.py
@@ -0,0 +1,115 @@
+import re
+import os
+
+from fastapi import Request
+from starlette.middleware.base import BaseHTTPMiddleware
+from typing import Dict
+
+
+class SecurityHeadersMiddleware(BaseHTTPMiddleware):
+ async def dispatch(self, request: Request, call_next):
+ response = await call_next(request)
+ response.headers.update(set_security_headers())
+ return response
+
+
+def set_security_headers() -> Dict[str, str]:
+ """
+ Sets security headers based on environment variables.
+
+ This function reads specific environment variables and uses their values
+ to set corresponding security headers. The headers that can be set are:
+ - cache-control
+ - strict-transport-security
+ - referrer-policy
+ - x-content-type-options
+ - x-download-options
+ - x-frame-options
+ - x-permitted-cross-domain-policies
+
+ Each environment variable is associated with a specific setter function
+ that constructs the header. If the environment variable is set, the
+ corresponding header is added to the options dictionary.
+
+ Returns:
+ dict: A dictionary containing the security headers and their values.
+ """
+ options = {}
+ header_setters = {
+ "CACHE_CONTROL": set_cache_control,
+ "HSTS": set_hsts,
+ "REFERRER_POLICY": set_referrer,
+ "XCONTENT_TYPE": set_xcontent_type,
+ "XDOWNLOAD_OPTIONS": set_xdownload_options,
+ "XFRAME_OPTIONS": set_xframe,
+ "XPERMITTED_CROSS_DOMAIN_POLICIES": set_xpermitted_cross_domain_policies,
+ }
+
+ for env_var, setter in header_setters.items():
+ value = os.environ.get(env_var, None)
+ if value:
+ header = setter(value)
+ if header:
+ options.update(header)
+
+ return options
+
+
+# Set the HTTP Strict-Transport-Security (HSTS) response header
+def set_hsts(value: str):
+ pattern = r"^max-age=(\d+)(;includeSubDomains)?(;preload)?$"
+ match = re.match(pattern, value, re.IGNORECASE)
+    if not match:
+        value = "max-age=31536000;includeSubDomains"
+    return {"Strict-Transport-Security": value}
+
+
+# Set X-Frame-Options response header
+def set_xframe(value: str):
+ pattern = r"^(DENY|SAMEORIGIN)$"
+ match = re.match(pattern, value, re.IGNORECASE)
+ if not match:
+ value = "DENY"
+ return {"X-Frame-Options": value}
+
+
+# Set Referrer-Policy response header
+def set_referrer(value: str):
+ pattern = r"^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$"
+ match = re.match(pattern, value, re.IGNORECASE)
+ if not match:
+ value = "no-referrer"
+ return {"Referrer-Policy": value}
+
+
+# Set Cache-Control response header
+def set_cache_control(value: str):
+ pattern = r"^(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable)(,\s*(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable))*$"
+ match = re.match(pattern, value, re.IGNORECASE)
+ if not match:
+ value = "no-store, max-age=0"
+
+ return {"Cache-Control": value}
+
+
+# Set X-Download-Options response header
+def set_xdownload_options(value: str):
+ if value != "noopen":
+ value = "noopen"
+ return {"X-Download-Options": value}
+
+
+# Set X-Content-Type-Options response header
+def set_xcontent_type(value: str):
+ if value != "nosniff":
+ value = "nosniff"
+ return {"X-Content-Type-Options": value}
+
+
+# Set X-Permitted-Cross-Domain-Policies response header
+def set_xpermitted_cross_domain_policies(value: str):
+ pattern = r"^(none|master-only|by-content-type|by-ftp-filename)$"
+ match = re.match(pattern, value, re.IGNORECASE)
+ if not match:
+ value = "none"
+ return {"X-Permitted-Cross-Domain-Policies": value}
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 93720cc84..2554bb5f8 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -40,7 +40,12 @@ langchain-chroma==0.1.2
fake-useragent==1.5.1
chromadb==0.5.5
+pymilvus==2.4.6
+
sentence-transformers==3.0.1
+colbert-ai==0.2.21
+einops==0.8.0
+
pypdf==4.3.1
docx2txt==0.8
python-pptx==1.0.0
diff --git a/kubernetes/manifest/base/kustomization.yaml b/kubernetes/manifest/base/kustomization.yaml
new file mode 100644
index 000000000..61500f87c
--- /dev/null
+++ b/kubernetes/manifest/base/kustomization.yaml
@@ -0,0 +1,8 @@
+resources:
+ - open-webui.yaml
+ - ollama-service.yaml
+ - ollama-statefulset.yaml
+ - webui-deployment.yaml
+ - webui-service.yaml
+ - webui-ingress.yaml
+ - webui-pvc.yaml
diff --git a/kubernetes/manifest/gpu/kustomization.yaml b/kubernetes/manifest/gpu/kustomization.yaml
new file mode 100644
index 000000000..c0d39fbfa
--- /dev/null
+++ b/kubernetes/manifest/gpu/kustomization.yaml
@@ -0,0 +1,8 @@
+apiVersion: kustomize.config.k8s.io/v1beta1
+kind: Kustomization
+
+resources:
+ - ../base
+
+patches:
+- path: ollama-statefulset-gpu.yaml
diff --git a/kubernetes/manifest/patches/ollama-statefulset-gpu.yaml b/kubernetes/manifest/gpu/ollama-statefulset-gpu.yaml
similarity index 100%
rename from kubernetes/manifest/patches/ollama-statefulset-gpu.yaml
rename to kubernetes/manifest/gpu/ollama-statefulset-gpu.yaml
diff --git a/kubernetes/manifest/kustomization.yaml b/kubernetes/manifest/kustomization.yaml
deleted file mode 100644
index 907bff3e1..000000000
--- a/kubernetes/manifest/kustomization.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-resources:
-- base/open-webui.yaml
-- base/ollama-service.yaml
-- base/ollama-statefulset.yaml
-- base/webui-deployment.yaml
-- base/webui-service.yaml
-- base/webui-ingress.yaml
-- base/webui-pvc.yaml
-
-apiVersion: kustomize.config.k8s.io/v1beta1
-kind: Kustomization
-patches:
-- path: patches/ollama-statefulset-gpu.yaml
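
With the manifests reorganized this way, the CPU-only stack can be applied with `kubectl apply -k kubernetes/manifest/base`, and `kubectl apply -k kubernetes/manifest/gpu` layers the renamed `ollama-statefulset-gpu.yaml` patch over the same base resources.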
diff --git a/package-lock.json b/package-lock.json
index 9ec44ffa7..a0048a211 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,18 +1,19 @@
{
"name": "open-webui",
- "version": "0.3.21",
+ "version": "0.3.22",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "open-webui",
- "version": "0.3.21",
+ "version": "0.3.22",
"dependencies": {
"@codemirror/lang-javascript": "^6.2.2",
"@codemirror/lang-python": "^6.1.6",
"@codemirror/theme-one-dark": "^6.1.2",
"@pyscript/core": "^0.4.32",
"@sveltejs/adapter-node": "^2.0.0",
+ "@xyflow/svelte": "^0.1.19",
"async": "^3.2.5",
"bits-ui": "^0.19.7",
"codemirror": "^6.0.1",
@@ -45,7 +46,6 @@
"@sveltejs/kit": "^2.5.20",
"@sveltejs/vite-plugin-svelte": "^3.1.1",
"@tailwindcss/typography": "^0.5.13",
- "@types/bun": "latest",
"@typescript-eslint/eslint-plugin": "^6.17.0",
"@typescript-eslint/parser": "^6.17.0",
"autoprefixer": "^10.4.16",
@@ -1368,6 +1368,14 @@
"resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz",
"integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA=="
},
+ "node_modules/@svelte-put/shortcut": {
+ "version": "3.1.1",
+ "resolved": "https://registry.npmjs.org/@svelte-put/shortcut/-/shortcut-3.1.1.tgz",
+ "integrity": "sha512-2L5EYTZXiaKvbEelVkg5znxqvfZGZai3m97+cAiUBhLZwXnGtviTDpHxOoZBsqz41szlfRMcamW/8o0+fbW3ZQ==",
+ "peerDependencies": {
+ "svelte": "^3.55.0 || ^4.0.0 || ^5.0.0"
+ }
+ },
"node_modules/@sveltejs/adapter-auto": {
"version": "3.2.2",
"resolved": "https://registry.npmjs.org/@sveltejs/adapter-auto/-/adapter-auto-3.2.2.tgz",
@@ -1494,20 +1502,32 @@
"tailwindcss": ">=3.0.0 || insiders"
}
},
- "node_modules/@types/bun": {
- "version": "1.0.10",
- "resolved": "https://registry.npmjs.org/@types/bun/-/bun-1.0.10.tgz",
- "integrity": "sha512-Jaz6YYAdm1u3NVlgSyEK+qGmrlLQ20sbWeEoXD64b9w6z/YKYNWlfaphu+xF2Kiy5Tpykm5Q9jIquLegwXx4ng==",
- "dev": true,
- "dependencies": {
- "bun-types": "1.0.33"
- }
- },
"node_modules/@types/cookie": {
"version": "0.6.0",
"resolved": "https://registry.npmjs.org/@types/cookie/-/cookie-0.6.0.tgz",
"integrity": "sha512-4Kh9a6B2bQciAhf7FSuMRRkUWecJgJu9nPnx3yzpsfXX/c50REIqpHY4C82bXP90qrLtXtkDxTZosYO3UpOwlA=="
},
+ "node_modules/@types/d3-color": {
+ "version": "3.1.3",
+ "resolved": "https://registry.npmjs.org/@types/d3-color/-/d3-color-3.1.3.tgz",
+ "integrity": "sha512-iO90scth9WAbmgv7ogoq57O9YpKmFBbmoEoCHDB2xMBY0+/KVrqAaCDyCE16dUspeOvIxFFRI+0sEtqDqy2b4A=="
+ },
+ "node_modules/@types/d3-drag": {
+ "version": "3.0.7",
+ "resolved": "https://registry.npmjs.org/@types/d3-drag/-/d3-drag-3.0.7.tgz",
+ "integrity": "sha512-HE3jVKlzU9AaMazNufooRJ5ZpWmLIoc90A37WU2JMmeq28w1FQqCZswHZ3xR+SuxYftzHq6WU6KJHvqxKzTxxQ==",
+ "dependencies": {
+ "@types/d3-selection": "*"
+ }
+ },
+ "node_modules/@types/d3-interpolate": {
+ "version": "3.0.4",
+ "resolved": "https://registry.npmjs.org/@types/d3-interpolate/-/d3-interpolate-3.0.4.tgz",
+ "integrity": "sha512-mgLPETlrpVV1YRJIglr4Ez47g7Yxjl1lj7YKsiMCb27VJH9W8NVM6Bb9d8kkpG/uAQS5AmbA48q2IAolKKo1MA==",
+ "dependencies": {
+ "@types/d3-color": "*"
+ }
+ },
"node_modules/@types/d3-scale": {
"version": "4.0.8",
"resolved": "https://registry.npmjs.org/@types/d3-scale/-/d3-scale-4.0.8.tgz",
@@ -1521,11 +1541,33 @@
"resolved": "https://registry.npmjs.org/@types/d3-scale-chromatic/-/d3-scale-chromatic-3.0.3.tgz",
"integrity": "sha512-laXM4+1o5ImZv3RpFAsTRn3TEkzqkytiOY0Dz0sq5cnd1dtNlk6sHLon4OvqaiJb28T0S/TdsBI3Sjsy+keJrw=="
},
+ "node_modules/@types/d3-selection": {
+ "version": "3.0.10",
+ "resolved": "https://registry.npmjs.org/@types/d3-selection/-/d3-selection-3.0.10.tgz",
+ "integrity": "sha512-cuHoUgS/V3hLdjJOLTT691+G2QoqAjCVLmr4kJXR4ha56w1Zdu8UUQ5TxLRqudgNjwXeQxKMq4j+lyf9sWuslg=="
+ },
"node_modules/@types/d3-time": {
"version": "3.0.3",
"resolved": "https://registry.npmjs.org/@types/d3-time/-/d3-time-3.0.3.tgz",
"integrity": "sha512-2p6olUZ4w3s+07q3Tm2dbiMZy5pCDfYwtLXXHUnVzXgQlZ/OyPtUz6OL382BkOuGlLXqfT+wqv8Fw2v8/0geBw=="
},
+ "node_modules/@types/d3-transition": {
+ "version": "3.0.8",
+ "resolved": "https://registry.npmjs.org/@types/d3-transition/-/d3-transition-3.0.8.tgz",
+ "integrity": "sha512-ew63aJfQ/ms7QQ4X7pk5NxQ9fZH/z+i24ZfJ6tJSfqxJMrYLiK01EAs2/Rtw/JreGUsS3pLPNV644qXFGnoZNQ==",
+ "dependencies": {
+ "@types/d3-selection": "*"
+ }
+ },
+ "node_modules/@types/d3-zoom": {
+ "version": "3.0.8",
+ "resolved": "https://registry.npmjs.org/@types/d3-zoom/-/d3-zoom-3.0.8.tgz",
+ "integrity": "sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==",
+ "dependencies": {
+ "@types/d3-interpolate": "*",
+ "@types/d3-selection": "*"
+ }
+ },
"node_modules/@types/debug": {
"version": "4.1.12",
"resolved": "https://registry.npmjs.org/@types/debug/-/debug-4.1.12.tgz",
@@ -1568,7 +1610,7 @@
"version": "20.11.30",
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.11.30.tgz",
"integrity": "sha512-dHM6ZxwlmuZaRmUPfv1p+KrdD1Dci04FbdEm/9wEMouFqxYoFl5aMkt0VMAUtYRQDyYvD41WJLukhq/ha3YuTw==",
- "devOptional": true,
+ "optional": true,
"dependencies": {
"undici-types": "~5.26.4"
}
@@ -1613,15 +1655,6 @@
"resolved": "https://registry.npmjs.org/@types/unist/-/unist-2.0.10.tgz",
"integrity": "sha512-IfYcSBWE3hLpBg8+X2SEa8LVkJdJEkT2Ese2aaLs3ptGdVtABxndrMaxuFlQ1qdFf9Q5rDvDpxI3WwgvKFAsQA=="
},
- "node_modules/@types/ws": {
- "version": "8.5.10",
- "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.10.tgz",
- "integrity": "sha512-vmQSUcfalpIq0R9q7uTo2lXs6eGIpt9wtnLdMv9LVpIjCA/+ufZRozlVoVelIYixx1ugCBKDhn89vnsEGOCx9A==",
- "dev": true,
- "dependencies": {
- "@types/node": "*"
- }
- },
"node_modules/@types/yauzl": {
"version": "2.10.3",
"resolved": "https://registry.npmjs.org/@types/yauzl/-/yauzl-2.10.3.tgz",
@@ -1942,6 +1975,33 @@
"resolved": "https://registry.npmjs.org/@webreflection/fetch/-/fetch-0.1.5.tgz",
"integrity": "sha512-zCcqCJoNLvdeF41asAK71XPlwSPieeRDsE09albBunJEksuYPYNillKNQjf8p5BqSoTKTuKrW3lUm3MNodUC4g=="
},
+ "node_modules/@xyflow/svelte": {
+ "version": "0.1.19",
+ "resolved": "https://registry.npmjs.org/@xyflow/svelte/-/svelte-0.1.19.tgz",
+ "integrity": "sha512-yW5w5aI+Yqkob4kLQpVDo/ZmX+E9Pw7459kqwLfv4YG4N1NYXrsDRh9cyph/rapbuDnPi6zqK5E8LKrgaCQC0w==",
+ "dependencies": {
+ "@svelte-put/shortcut": "^3.1.0",
+ "@xyflow/system": "0.0.42",
+ "classcat": "^5.0.4"
+ },
+ "peerDependencies": {
+ "svelte": "^3.0.0 || ^4.0.0"
+ }
+ },
+ "node_modules/@xyflow/system": {
+ "version": "0.0.42",
+ "resolved": "https://registry.npmjs.org/@xyflow/system/-/system-0.0.42.tgz",
+ "integrity": "sha512-kWYj+Y0GOct0jKYTdyRMNOLPxGNbb2TYvPg2gTmJnZ31DOOMkL5uRBLX825DR2gOACDu+i5FHLxPJUPf/eGOJw==",
+ "dependencies": {
+ "@types/d3-drag": "^3.0.7",
+ "@types/d3-selection": "^3.0.10",
+ "@types/d3-transition": "^3.0.8",
+ "@types/d3-zoom": "^3.0.8",
+ "d3-drag": "^3.0.0",
+ "d3-selection": "^3.0.0",
+ "d3-zoom": "^3.0.0"
+ }
+ },
"node_modules/acorn": {
"version": "8.11.3",
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.11.3.tgz",
@@ -2533,16 +2593,6 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
- "node_modules/bun-types": {
- "version": "1.0.33",
- "resolved": "https://registry.npmjs.org/bun-types/-/bun-types-1.0.33.tgz",
- "integrity": "sha512-L5tBIf9g6rBBkvshqysi5NoLQ9NnhSPU1pfJ9FzqoSfofYdyac3WLUnOIuQ+M5za/sooVUOP2ko+E6Tco0OLIA==",
- "dev": true,
- "dependencies": {
- "@types/node": "~20.11.3",
- "@types/ws": "~8.5.10"
- }
- },
"node_modules/cac": {
"version": "6.7.14",
"resolved": "https://registry.npmjs.org/cac/-/cac-6.7.14.tgz",
@@ -2777,6 +2827,11 @@
"node": ">=8"
}
},
+ "node_modules/classcat": {
+ "version": "5.0.5",
+ "resolved": "https://registry.npmjs.org/classcat/-/classcat-5.0.5.tgz",
+ "integrity": "sha512-JhZUT7JFcQy/EzW605k/ktHtncoo9vnyW/2GspNYwFlN1C/WmjuV/xtS04e9SOkL2sTdw0VAZ2UGCcQ9lR6p6w=="
+ },
"node_modules/clean-stack": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/clean-stack/-/clean-stack-2.2.0.tgz",
@@ -7094,9 +7149,9 @@
}
},
"node_modules/picocolors": {
- "version": "1.0.1",
- "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.1.tgz",
- "integrity": "sha512-anP1Z8qwhkbmu7MFP5iTt+wQKXgwzf7zTyGlcdzabySa9vd0Xt392U0rVmz9poOaBj0uHJKyyo9/upk0HrEQew=="
+ "version": "1.1.0",
+ "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.0.tgz",
+ "integrity": "sha512-TQ92mBOW0l3LeMeyLV6mzy/kWr8lkd/hp3mTg7wYK7zJhuBStmGMBG0BdeDZS/dZx1IukaX6Bk11zcln25o1Aw=="
},
"node_modules/picomatch": {
"version": "2.3.1",
@@ -7162,9 +7217,9 @@
}
},
"node_modules/postcss": {
- "version": "8.4.41",
- "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.41.tgz",
- "integrity": "sha512-TesUflQ0WKZqAvg52PWL6kHgLKP6xB6heTOdoYM0Wt2UHyxNa4K25EZZMgKns3BH1RLVbZCREPpLY0rhnNoHVQ==",
+ "version": "8.4.47",
+ "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.47.tgz",
+ "integrity": "sha512-56rxCq7G/XfB4EkXq9Egn5GCqugWvDFjafDOThIdMBsI15iqPqR5r15TfSr1YPYeEI19YeaXMCbY6u88Y76GLQ==",
"funding": [
{
"type": "opencollective",
@@ -7181,8 +7236,8 @@
],
"dependencies": {
"nanoid": "^3.3.7",
- "picocolors": "^1.0.1",
- "source-map-js": "^1.2.0"
+ "picocolors": "^1.1.0",
+ "source-map-js": "^1.2.1"
},
"engines": {
"node": "^10 || ^12 || >=14"
@@ -8195,9 +8250,9 @@
"integrity": "sha512-FJF5jgdfvoKn1MAKSdGs33bIqLi3LmsgVTliuX6iITj834F+JRQZN90Z93yql8h0K2t0RwDPBmxwlbZfDcxNZA=="
},
"node_modules/source-map-js": {
- "version": "1.2.0",
- "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.0.tgz",
- "integrity": "sha512-itJW8lvSA0TXEphiRoawsCksnlf8SyvmFzIhltqAHluXd88pkCd+cXJVHTDwdCr0IzwptSm035IHQktUu1QUMg==",
+ "version": "1.2.1",
+ "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz",
+ "integrity": "sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==",
"engines": {
"node": ">=0.10.0"
}
@@ -9112,7 +9167,7 @@
"version": "5.26.5",
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
- "devOptional": true
+ "optional": true
},
"node_modules/unist-util-stringify-position": {
"version": "3.0.3",
@@ -9329,13 +9384,13 @@
}
},
"node_modules/vite": {
- "version": "5.4.0",
- "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.0.tgz",
- "integrity": "sha512-5xokfMX0PIiwCMCMb9ZJcMyh5wbBun0zUzKib+L65vAZ8GY9ePZMXxFrHbr/Kyll2+LSCY7xtERPpxkBDKngwg==",
+ "version": "5.4.6",
+ "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.6.tgz",
+ "integrity": "sha512-IeL5f8OO5nylsgzd9tq4qD2QqI0k2CQLGrWD0rCN0EQJZpBK5vJAx0I+GDkMOXxQX/OfFHMuLIx6ddAxGX/k+Q==",
"dependencies": {
"esbuild": "^0.21.3",
- "postcss": "^8.4.40",
- "rollup": "^4.13.0"
+ "postcss": "^8.4.43",
+ "rollup": "^4.20.0"
},
"bin": {
"vite": "bin/vite.js"
diff --git a/package.json b/package.json
index 08c101384..371507789 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
{
"name": "open-webui",
- "version": "0.3.21",
+ "version": "0.3.22",
"private": true,
"scripts": {
"dev": "npm run pyodide:fetch && vite dev --host",
@@ -25,7 +25,6 @@
"@sveltejs/kit": "^2.5.20",
"@sveltejs/vite-plugin-svelte": "^3.1.1",
"@tailwindcss/typography": "^0.5.13",
- "@types/bun": "latest",
"@typescript-eslint/eslint-plugin": "^6.17.0",
"@typescript-eslint/parser": "^6.17.0",
"autoprefixer": "^10.4.16",
@@ -54,6 +53,7 @@
"@codemirror/theme-one-dark": "^6.1.2",
"@pyscript/core": "^0.4.32",
"@sveltejs/adapter-node": "^2.0.0",
+ "@xyflow/svelte": "^0.1.19",
"async": "^3.2.5",
"bits-ui": "^0.19.7",
"codemirror": "^6.0.1",
diff --git a/pyproject.toml b/pyproject.toml
index 057ef1475..b2558e4d1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,7 +47,12 @@ dependencies = [
"fake-useragent==1.5.1",
"chromadb==0.5.5",
+ "pymilvus==2.4.6",
+
"sentence-transformers==3.0.1",
+ "colbert-ai==0.2.21",
+ "einops==0.8.0",
+
"pypdf==4.3.1",
"docx2txt==0.8",
"python-pptx==1.0.0",
diff --git a/src/app.css b/src/app.css
index a421d90ae..65103b55a 100644
--- a/src/app.css
+++ b/src/app.css
@@ -50,21 +50,6 @@ iframe {
@apply rounded-lg;
}
-ol > li {
- counter-increment: list-number;
- display: block;
- margin-bottom: 0;
- margin-top: 0;
- min-height: 28px;
-}
-
-.prose ol > li::before {
- content: counters(list-number, '.') '.';
- padding-right: 0.5rem;
- color: var(--tw-prose-counters);
- font-weight: 400;
-}
-
li p {
display: inline;
}
@@ -171,3 +156,20 @@ input[type='number'] {
font-weight: 600;
@apply rounded-md dark:bg-gray-800 bg-gray-100 mx-0.5;
}
+
+.svelte-flow {
+ background-color: transparent !important;
+}
+
+.svelte-flow__edge > path {
+ stroke-width: 0.5;
+}
+
+.svelte-flow__edge.animated > path {
+ stroke-width: 2;
+ @apply stroke-gray-600 dark:stroke-gray-500;
+}
+
+.bg-gray-950-90 {
+ background-color: rgba(var(--color-gray-950, #0d0d0d), 0.9);
+}
diff --git a/src/app.html b/src/app.html
index 59fd7c5ed..d7f4513e7 100644
--- a/src/app.html
+++ b/src/app.html
@@ -4,7 +4,11 @@
-		<script>
+		<meta name="theme-color" content="#171717" />
+		<script>
// On page load or when changing themes, best to add inline in `head` to avoid FOUC
(() => {
+ const metaThemeColorTag = document.querySelector('meta[name="theme-color"]');
+ const prefersDarkTheme = window.matchMedia('(prefers-color-scheme: dark)').matches;
+
if (!localStorage?.theme) {
localStorage.theme = 'system';
}
- if (localStorage?.theme && localStorage?.theme.includes('oled')) {
+ if (localStorage.theme === 'system') {
+ document.documentElement.classList.add(prefersDarkTheme ? 'dark' : 'light');
+ metaThemeColorTag.setAttribute('content', prefersDarkTheme ? '#171717' : '#ffffff');
+ } else if (localStorage.theme === 'oled-dark') {
document.documentElement.style.setProperty('--color-gray-800', '#101010');
document.documentElement.style.setProperty('--color-gray-850', '#050505');
document.documentElement.style.setProperty('--color-gray-900', '#000000');
document.documentElement.style.setProperty('--color-gray-950', '#000000');
document.documentElement.classList.add('dark');
- } else if (
- localStorage.theme === 'light' ||
- (!('theme' in localStorage) && window.matchMedia('(prefers-color-scheme: light)').matches)
- ) {
+ metaThemeColorTag.setAttribute('content', '#000000');
+ } else if (localStorage.theme === 'light') {
document.documentElement.classList.add('light');
- } else if (localStorage.theme && localStorage.theme !== 'system') {
- localStorage.theme.split(' ').forEach((e) => {
- document.documentElement.classList.add(e);
- });
- } else if (localStorage.theme && localStorage.theme === 'system') {
- systemTheme = window.matchMedia('(prefers-color-scheme: dark)').matches;
- document.documentElement.classList.add(systemTheme ? 'dark' : 'light');
- } else if (localStorage.theme && localStorage.theme === 'her') {
+ metaThemeColorTag.setAttribute('content', '#ffffff');
+ } else if (localStorage.theme === 'her') {
document.documentElement.classList.add('dark');
document.documentElement.classList.add('her');
+ metaThemeColorTag.setAttribute('content', '#983724');
} else {
document.documentElement.classList.add('dark');
+ metaThemeColorTag.setAttribute('content', '#171717');
}
window.matchMedia('(prefers-color-scheme: dark)').addListener((e) => {
@@ -57,9 +61,11 @@
if (e.matches) {
document.documentElement.classList.add('dark');
document.documentElement.classList.remove('light');
+ metaThemeColorTag.setAttribute('content', '#171717');
} else {
document.documentElement.classList.add('light');
document.documentElement.classList.remove('dark');
+ metaThemeColorTag.setAttribute('content', '#ffffff');
}
}
});
diff --git a/src/lib/components/admin/Settings/Audio.svelte b/src/lib/components/admin/Settings/Audio.svelte
index 1c114c9dd..040bc5e1a 100644
--- a/src/lib/components/admin/Settings/Audio.svelte
+++ b/src/lib/components/admin/Settings/Audio.svelte
@@ -31,6 +31,8 @@
let TTS_MODEL = '';
let TTS_VOICE = '';
let TTS_SPLIT_ON: TTS_RESPONSE_SPLIT = TTS_RESPONSE_SPLIT.PUNCTUATION;
+ let TTS_AZURE_SPEECH_REGION = '';
+ let TTS_AZURE_SPEECH_OUTPUT_FORMAT = '';
let STT_OPENAI_API_BASE_URL = '';
let STT_OPENAI_API_KEY = '';
@@ -87,7 +89,9 @@
ENGINE: TTS_ENGINE,
MODEL: TTS_MODEL,
VOICE: TTS_VOICE,
- SPLIT_ON: TTS_SPLIT_ON
+ SPLIT_ON: TTS_SPLIT_ON,
+ AZURE_SPEECH_REGION: TTS_AZURE_SPEECH_REGION,
+ AZURE_SPEECH_OUTPUT_FORMAT: TTS_AZURE_SPEECH_OUTPUT_FORMAT
},
stt: {
OPENAI_API_BASE_URL: STT_OPENAI_API_BASE_URL,
@@ -120,6 +124,9 @@
TTS_SPLIT_ON = res.tts.SPLIT_ON || TTS_RESPONSE_SPLIT.PUNCTUATION;
+ TTS_AZURE_SPEECH_OUTPUT_FORMAT = res.tts.AZURE_SPEECH_OUTPUT_FORMAT;
+ TTS_AZURE_SPEECH_REGION = res.tts.AZURE_SPEECH_REGION;
+
STT_OPENAI_API_BASE_URL = res.stt.OPENAI_API_BASE_URL;
STT_OPENAI_API_KEY = res.stt.OPENAI_API_KEY;
@@ -224,6 +231,7 @@
+
@@ -252,6 +260,23 @@
/>
+ {:else if TTS_ENGINE === 'azure'}
+
{/if}
@@ -359,6 +384,49 @@
+ {:else if TTS_ENGINE === 'azure'}
+
+
+
{$i18n.t('TTS Voice')}
+
+
+
+
+
+
+
+
+
+
{/if}
diff --git a/src/lib/components/admin/Settings/Connections.svelte b/src/lib/components/admin/Settings/Connections.svelte
index fe71e4816..97a760110 100644
--- a/src/lib/components/admin/Settings/Connections.svelte
+++ b/src/lib/components/admin/Settings/Connections.svelte
@@ -150,18 +150,20 @@
})()
]);
- OPENAI_API_BASE_URLS.forEach(async (url, idx) => {
- const res = await getOpenAIModels(localStorage.token, idx);
- if (res.pipelines) {
- pipelineUrls[url] = true;
- }
- });
-
const ollamaConfig = await getOllamaConfig(localStorage.token);
const openaiConfig = await getOpenAIConfig(localStorage.token);
ENABLE_OPENAI_API = openaiConfig.ENABLE_OPENAI_API;
ENABLE_OLLAMA_API = ollamaConfig.ENABLE_OLLAMA_API;
+
+ if (ENABLE_OPENAI_API) {
+ OPENAI_API_BASE_URLS.forEach(async (url, idx) => {
+ const res = await getOpenAIModels(localStorage.token, idx);
+ if (res.pipelines) {
+ pipelineUrls[url] = true;
+ }
+ });
+ }
}
});
diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte
index 7f935a08a..e06edce9d 100644
--- a/src/lib/components/admin/Settings/Documents.svelte
+++ b/src/lib/components/admin/Settings/Documents.svelte
@@ -732,11 +732,17 @@
{$i18n.t('RAG Template')}
-
+
+
+
diff --git a/src/lib/components/admin/Settings/Images.svelte b/src/lib/components/admin/Settings/Images.svelte
index 91ce4f280..09b3c77d0 100644
--- a/src/lib/components/admin/Settings/Images.svelte
+++ b/src/lib/components/admin/Settings/Images.svelte
@@ -27,6 +27,43 @@
let models = null;
+ let samplers = [
+ 'DPM++ 2M',
+ 'DPM++ SDE',
+ 'DPM++ 2M SDE',
+ 'DPM++ 2M SDE Heun',
+ 'DPM++ 2S a',
+ 'DPM++ 3M SDE',
+ 'Euler a',
+ 'Euler',
+ 'LMS',
+ 'Heun',
+ 'DPM2',
+ 'DPM2 a',
+ 'DPM fast',
+ 'DPM adaptive',
+ 'Restart',
+ 'DDIM',
+ 'DDIM CFG++',
+ 'PLMS',
+ 'UniPC'
+ ];
+
+ let schedulers = [
+ 'Automatic',
+ 'Uniform',
+ 'Karras',
+ 'Exponential',
+ 'Polyexponential',
+ 'SGM Uniform',
+ 'KL Optimal',
+ 'Align Your Steps',
+ 'Simple',
+ 'Normal',
+ 'DDIM',
+ 'Beta'
+ ];
+
let requiredWorkflowNodes = [
{
type: 'prompt',
@@ -326,6 +363,66 @@
+
+
+
+
{$i18n.t('Set Sampler')}
+
+
+
+
+
+
+
+
+
+
+
+
+
{$i18n.t('Set Scheduler')}
+
+
+
+
+
+
+
+
+
+
+
+
+
{$i18n.t('Set CFG Scale')}
+
+
{:else if config?.engine === 'comfyui'}
{$i18n.t('ComfyUI Base URL')}
diff --git a/src/lib/components/admin/Settings/Interface.svelte b/src/lib/components/admin/Settings/Interface.svelte
index 295f3d208..4ccdc80de 100644
--- a/src/lib/components/admin/Settings/Interface.svelte
+++ b/src/lib/components/admin/Settings/Interface.svelte
@@ -307,9 +307,10 @@
/>
-
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index c84e8d84b..e0ec62b52 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -23,6 +23,7 @@
banners,
user,
socket,
+ showControls,
showCallOverlay,
currentChatPage,
temporaryChatEnabled
@@ -70,7 +71,6 @@
let loaded = false;
const eventTarget = new EventTarget();
- let showControls = false;
let stopResponseFlag = false;
let autoScroll = true;
let processing = '';
@@ -115,13 +115,14 @@
$: if (history.currentId !== null) {
let _messages = [];
-
let currentMessage = history.messages[history.currentId];
- while (currentMessage !== null) {
+ while (currentMessage) {
_messages.unshift({ ...currentMessage });
currentMessage =
currentMessage.parentId !== null ? history.messages[currentMessage.parentId] : null;
}
+
+ // This is most likely causing the performance issue
messages = _messages;
} else {
messages = [];
@@ -143,6 +144,28 @@
})();
}
+ const showMessage = async (message) => {
+ let _messageId = JSON.parse(JSON.stringify(message.id));
+
+ let messageChildrenIds = history.messages[_messageId].childrenIds;
+
+ while (messageChildrenIds.length !== 0) {
+ _messageId = messageChildrenIds.at(-1);
+ messageChildrenIds = history.messages[_messageId].childrenIds;
+ }
+
+ history.currentId = _messageId;
+
+ await tick();
+ await tick();
+ await tick();
+
+ const messageElement = document.getElementById(`message-${message.id}`);
+ if (messageElement) {
+ messageElement.scrollIntoView({ behavior: 'smooth' });
+ }
+ };
+
const chatEventHandler = async (event, cb) => {
if (event.chat_id === $chatId) {
await tick();
@@ -860,8 +883,9 @@
await tick();
+ const stream = $settings?.streamResponse ?? true;
const [res, controller] = await generateChatCompletion(localStorage.token, {
- stream: true,
+ stream: stream,
model: model.id,
messages: messagesBody,
options: {
@@ -886,142 +910,162 @@
});
if (res && res.ok) {
- console.log('controller', controller);
+ if (!stream) {
+ const response = await res.json();
+ console.log(response);
- const reader = res.body
- .pipeThrough(new TextDecoderStream())
- .pipeThrough(splitStream('\n'))
- .getReader();
+ responseMessage.content = response.message.content;
+ responseMessage.info = {
+ eval_count: response.eval_count,
+ eval_duration: response.eval_duration,
+ load_duration: response.load_duration,
+ prompt_eval_count: response.prompt_eval_count,
+ prompt_eval_duration: response.prompt_eval_duration,
+ total_duration: response.total_duration
+ };
+ responseMessage.done = true;
+ } else {
+ console.log('controller', controller);
- while (true) {
- const { value, done } = await reader.read();
- if (done || stopResponseFlag || _chatId !== $chatId) {
- responseMessage.done = true;
- messages = messages;
+ const reader = res.body
+ .pipeThrough(new TextDecoderStream())
+ .pipeThrough(splitStream('\n'))
+ .getReader();
- if (stopResponseFlag) {
- controller.abort('User: Stop Response');
- } else {
- const messages = createMessagesList(responseMessageId);
- await chatCompletedHandler(_chatId, model.id, responseMessageId, messages);
+ while (true) {
+ const { value, done } = await reader.read();
+ if (done || stopResponseFlag || _chatId !== $chatId) {
+ responseMessage.done = true;
+ messages = messages;
+
+ if (stopResponseFlag) {
+ controller.abort('User: Stop Response');
+ }
+
+ _response = responseMessage.content;
+ break;
}
- _response = responseMessage.content;
- break;
- }
+ try {
+ let lines = value.split('\n');
- try {
- let lines = value.split('\n');
+ for (const line of lines) {
+ if (line !== '') {
+ console.log(line);
+ let data = JSON.parse(line);
- for (const line of lines) {
- if (line !== '') {
- console.log(line);
- let data = JSON.parse(line);
-
- if ('citations' in data) {
- responseMessage.citations = data.citations;
- // Only remove status if it was initially set
- if (model?.info?.meta?.knowledge ?? false) {
- responseMessage.statusHistory = responseMessage.statusHistory.filter(
- (status) => status.action !== 'knowledge_search'
- );
- }
- continue;
- }
-
- if ('detail' in data) {
- throw data;
- }
-
- if (data.done == false) {
- if (responseMessage.content == '' && data.message.content == '\n') {
- continue;
- } else {
- responseMessage.content += data.message.content;
-
- if (navigator.vibrate && ($settings?.hapticFeedback ?? false)) {
- navigator.vibrate(5);
- }
-
- const messageContentParts = getMessageContentParts(
- responseMessage.content,
- $config?.audio?.tts?.split_on ?? 'punctuation'
- );
- messageContentParts.pop();
-
- // dispatch only last sentence and make sure it hasn't been dispatched before
- if (
- messageContentParts.length > 0 &&
- messageContentParts[messageContentParts.length - 1] !==
- responseMessage.lastSentence
- ) {
- responseMessage.lastSentence =
- messageContentParts[messageContentParts.length - 1];
- eventTarget.dispatchEvent(
- new CustomEvent('chat', {
- detail: {
- id: responseMessageId,
- content: messageContentParts[messageContentParts.length - 1]
- }
- })
+ if ('citations' in data) {
+ responseMessage.citations = data.citations;
+ // Only remove status if it was initially set
+ if (model?.info?.meta?.knowledge ?? false) {
+ responseMessage.statusHistory = responseMessage.statusHistory.filter(
+ (status) => status.action !== 'knowledge_search'
);
}
-
- messages = messages;
+ continue;
}
- } else {
- responseMessage.done = true;
- if (responseMessage.content == '') {
- responseMessage.error = {
- code: 400,
- content: `Oops! No text generated from Ollama, Please try again.`
+ if ('detail' in data) {
+ throw data;
+ }
+
+ if (data.done == false) {
+ if (responseMessage.content == '' && data.message.content == '\n') {
+ continue;
+ } else {
+ responseMessage.content += data.message.content;
+
+ if (navigator.vibrate && ($settings?.hapticFeedback ?? false)) {
+ navigator.vibrate(5);
+ }
+
+ const messageContentParts = getMessageContentParts(
+ responseMessage.content,
+ $config?.audio?.tts?.split_on ?? 'punctuation'
+ );
+ messageContentParts.pop();
+
+ // dispatch only last sentence and make sure it hasn't been dispatched before
+ if (
+ messageContentParts.length > 0 &&
+ messageContentParts[messageContentParts.length - 1] !==
+ responseMessage.lastSentence
+ ) {
+ responseMessage.lastSentence =
+ messageContentParts[messageContentParts.length - 1];
+ eventTarget.dispatchEvent(
+ new CustomEvent('chat', {
+ detail: {
+ id: responseMessageId,
+ content: messageContentParts[messageContentParts.length - 1]
+ }
+ })
+ );
+ }
+
+ messages = messages;
+ }
+ } else {
+ responseMessage.done = true;
+
+ if (responseMessage.content == '') {
+ responseMessage.error = {
+ code: 400,
+										content: `Oops! No text was generated from Ollama. Please try again.`
+ };
+ }
+
+ responseMessage.context = data.context ?? null;
+ responseMessage.info = {
+ total_duration: data.total_duration,
+ load_duration: data.load_duration,
+ sample_count: data.sample_count,
+ sample_duration: data.sample_duration,
+ prompt_eval_count: data.prompt_eval_count,
+ prompt_eval_duration: data.prompt_eval_duration,
+ eval_count: data.eval_count,
+ eval_duration: data.eval_duration
};
- }
+ messages = messages;
- responseMessage.context = data.context ?? null;
- responseMessage.info = {
- total_duration: data.total_duration,
- load_duration: data.load_duration,
- sample_count: data.sample_count,
- sample_duration: data.sample_duration,
- prompt_eval_count: data.prompt_eval_count,
- prompt_eval_duration: data.prompt_eval_duration,
- eval_count: data.eval_count,
- eval_duration: data.eval_duration
- };
- messages = messages;
+ if ($settings.notificationEnabled && !document.hasFocus()) {
+ const notification = new Notification(`${model.id}`, {
+ body: responseMessage.content,
+ icon: `${WEBUI_BASE_URL}/static/favicon.png`
+ });
+ }
- if ($settings.notificationEnabled && !document.hasFocus()) {
- const notification = new Notification(`${model.id}`, {
- body: responseMessage.content,
- icon: `${WEBUI_BASE_URL}/static/favicon.png`
- });
- }
+ if ($settings?.responseAutoCopy ?? false) {
+ copyToClipboard(responseMessage.content);
+ }
- if ($settings?.responseAutoCopy ?? false) {
- copyToClipboard(responseMessage.content);
- }
-
- if ($settings.responseAutoPlayback && !$showCallOverlay) {
- await tick();
- document.getElementById(`speak-button-${responseMessage.id}`)?.click();
+ if ($settings.responseAutoPlayback && !$showCallOverlay) {
+ await tick();
+ document.getElementById(`speak-button-${responseMessage.id}`)?.click();
+ }
}
}
}
+ } catch (error) {
+ console.log(error);
+ if ('detail' in error) {
+ toast.error(error.detail);
+ }
+ break;
}
- } catch (error) {
- console.log(error);
- if ('detail' in error) {
- toast.error(error.detail);
- }
- break;
- }
- if (autoScroll) {
- scrollToBottom();
+ if (autoScroll) {
+ scrollToBottom();
+ }
}
}
+
+ await chatCompletedHandler(
+ _chatId,
+ model.id,
+ responseMessageId,
+ createMessagesList(responseMessageId)
+ );
} else {
if (res !== null) {
const error = await res.json();
@@ -1133,17 +1177,19 @@
await tick();
try {
+ const stream = $settings?.streamResponse ?? true;
const [res, controller] = await generateOpenAIChatCompletion(
localStorage.token,
{
- stream: true,
+ stream: stream,
model: model.id,
- stream_options:
- (model.info?.meta?.capabilities?.usage ?? false)
- ? {
+ ...(stream && (model.info?.meta?.capabilities?.usage ?? false)
+ ? {
+ stream_options: {
include_usage: true
}
- : undefined,
+ }
+ : {}),
messages: [
params?.system || $settings.system || (responseMessage?.userContext ?? null)
? {
@@ -1221,85 +1267,95 @@
scrollToBottom();
if (res && res.ok && res.body) {
- const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
+ if (!stream) {
+ const response = await res.json();
+ console.log(response);
- for await (const update of textStream) {
- const { value, done, citations, error, usage } = update;
- if (error) {
- await handleOpenAIError(error, null, model, responseMessage);
- break;
- }
- if (done || stopResponseFlag || _chatId !== $chatId) {
- responseMessage.done = true;
- messages = messages;
+ responseMessage.content = response.choices[0].message.content;
+ responseMessage.info = { ...response.usage, openai: true };
+ responseMessage.done = true;
+ } else {
+ const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
- if (stopResponseFlag) {
- controller.abort('User: Stop Response');
+ for await (const update of textStream) {
+ const { value, done, citations, error, usage } = update;
+ if (error) {
+ await handleOpenAIError(error, null, model, responseMessage);
+ break;
+ }
+ if (done || stopResponseFlag || _chatId !== $chatId) {
+ responseMessage.done = true;
+ messages = messages;
+
+ if (stopResponseFlag) {
+ controller.abort('User: Stop Response');
+ }
+ _response = responseMessage.content;
+ break;
+ }
+
+ if (usage) {
+ responseMessage.info = { ...usage, openai: true };
+ }
+
+ if (citations) {
+ responseMessage.citations = citations;
+ // Only remove status if it was initially set
+ if (model?.info?.meta?.knowledge ?? false) {
+ responseMessage.statusHistory = responseMessage.statusHistory.filter(
+ (status) => status.action !== 'knowledge_search'
+ );
+ }
+ continue;
+ }
+
+ if (responseMessage.content == '' && value == '\n') {
+ continue;
} else {
- const messages = createMessagesList(responseMessageId);
+ responseMessage.content += value;
- await chatCompletedHandler(_chatId, model.id, responseMessageId, messages);
- }
+ if (navigator.vibrate && ($settings?.hapticFeedback ?? false)) {
+ navigator.vibrate(5);
+ }
- _response = responseMessage.content;
-
- break;
- }
-
- if (usage) {
- responseMessage.info = { ...usage, openai: true };
- }
-
- if (citations) {
- responseMessage.citations = citations;
- // Only remove status if it was initially set
- if (model?.info?.meta?.knowledge ?? false) {
- responseMessage.statusHistory = responseMessage.statusHistory.filter(
- (status) => status.action !== 'knowledge_search'
+ const messageContentParts = getMessageContentParts(
+ responseMessage.content,
+ $config?.audio?.tts?.split_on ?? 'punctuation'
);
- }
- continue;
- }
+ messageContentParts.pop();
- if (responseMessage.content == '' && value == '\n') {
- continue;
- } else {
- responseMessage.content += value;
+ // dispatch only last sentence and make sure it hasn't been dispatched before
+ if (
+ messageContentParts.length > 0 &&
+ messageContentParts[messageContentParts.length - 1] !== responseMessage.lastSentence
+ ) {
+ responseMessage.lastSentence = messageContentParts[messageContentParts.length - 1];
+ eventTarget.dispatchEvent(
+ new CustomEvent('chat', {
+ detail: {
+ id: responseMessageId,
+ content: messageContentParts[messageContentParts.length - 1]
+ }
+ })
+ );
+ }
- if (navigator.vibrate && ($settings?.hapticFeedback ?? false)) {
- navigator.vibrate(5);
+ messages = messages;
}
- const messageContentParts = getMessageContentParts(
- responseMessage.content,
- $config?.audio?.tts?.split_on ?? 'punctuation'
- );
- messageContentParts.pop();
-
- // dispatch only last sentence and make sure it hasn't been dispatched before
- if (
- messageContentParts.length > 0 &&
- messageContentParts[messageContentParts.length - 1] !== responseMessage.lastSentence
- ) {
- responseMessage.lastSentence = messageContentParts[messageContentParts.length - 1];
- eventTarget.dispatchEvent(
- new CustomEvent('chat', {
- detail: {
- id: responseMessageId,
- content: messageContentParts[messageContentParts.length - 1]
- }
- })
- );
+ if (autoScroll) {
+ scrollToBottom();
}
-
- messages = messages;
- }
-
- if (autoScroll) {
- scrollToBottom();
}
}
+ await chatCompletedHandler(
+ _chatId,
+ model.id,
+ responseMessageId,
+ createMessagesList(responseMessageId)
+ );
+
if ($settings.notificationEnabled && !document.hasFocus()) {
const notification = new Notification(`${model.id}`, {
body: responseMessage.content,
@@ -1703,7 +1759,6 @@
{title}
bind:selectedModels
bind:showModelSelector
- bind:showControls
shareEnabled={messages.length > 0}
{chat}
{initNewChat}
@@ -1713,7 +1768,7 @@