diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml
index feecd16c7..a1ea5c8e2 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yaml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yaml
@@ -89,9 +89,18 @@ body:
           required: true
         - label: I have included the Docker container logs.
           required: true
-        - label: I have listed steps to reproduce the bug in detail.
+        - label: I have clearly **listed every relevant configuration, custom setting, environment variable, and command-line option that influences my setup** (such as Docker Compose overrides, .env values, browser settings, authentication configurations, etc.).
+          required: true
+        - label: |
+            I have documented **step-by-step reproduction instructions that are precise, sequential, and leave nothing to interpretation**. My steps:
+            - Start with the initial platform/version/OS and dependencies used,
+            - Specify exact install/launch/configure commands,
+            - List URLs visited and user input (incl. example values/emails/passwords if needed),
+            - Describe all options and toggles enabled or changed,
+            - Include any files or environment changes,
+            - Identify the expected and actual result at each stage,
+            - Ensure any reasonably skilled user can follow them and hit the same issue.
+          required: true
-
   - type: textarea
     id: expected-behavior
     attributes:
@@ -112,15 +121,25 @@ body:
     id: reproduction-steps
     attributes:
       label: Steps to Reproduce
-      description: Providing clear, step-by-step instructions helps us reproduce and fix the issue faster. If we can't reproduce it, we can't fix it.
+      description: |
+        Please provide a **very detailed, step-by-step guide** to reproduce the issue. Your instructions should be so clear and precise that anyone can follow them without guesswork. Include every relevant detail—settings, configuration options, exact commands used, values entered, and any prerequisites or environment variables.
+        **If full reproduction steps and all relevant settings are not provided, your issue may not be addressed.**
+
       placeholder: |
-        1. Go to '...'
-        2. Click on '...'
-        3. Scroll down to '...'
-        4. See the error message '...'
+        Example (include every detail):
+        1. Start with a clean Ubuntu 22.04 install.
+        2. Install Docker v24.0.5 and start the service.
+        3. Clone the Open WebUI repo (git clone ...).
+        4. Use the Docker Compose file without modifications.
+        5. Open Chrome 115.0 in incognito mode.
+        6. Go to http://localhost:8080 and log in as user "test@example.com".
+        7. Set the language to "English" and the theme to "Dark".
+        8. Attempt to connect to Ollama at "http://localhost:11434".
+        9. Observe that the error message "Connection refused" appears at the top right.
+
+        Please list each step carefully and include all relevant configuration, settings, and options.
     validations:
       required: true
-
   - type: textarea
     id: logs-screenshots
     attributes:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d17f2cf2e..95a795d51 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,54 @@
 All notable changes to this project will be documented in this file.

 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [0.6.11] - 2025-05-27
+
+### Added
+
+- 🟢 **Ollama Model Status Indicator in Model Selector**: Instantly see which Ollama models are currently loaded with a clear indicator in the model selector, helping you stay organized and optimize local model usage.
+- 🗑️ **Unload Ollama Model Directly from Model Selector**: Easily release memory and resources by unloading any loaded Ollama model right in the model selector—streamline hardware management without switching pages.
+- 🗣️ **User-Configurable Speech-to-Text Language Setting**: Improve transcription accuracy by letting individual users explicitly set their preferred STT language in their settings—ideal for multilingual teams and clear audio capture.
+- ⚡ **Granular Audio Playback Speed Control**: Instead of just presets, you can now choose granular audio speed using a numeric input, giving you complete control over playback pace in transcriptions and media reviews.
+- 📦 **GZip, Brotli, ZStd Compression Middleware**: Enjoy significantly faster page loads and reduced bandwidth usage with new server-side compression—giving users a snappier, more efficient experience.
+- 🏷️ **Configurable Weight for BM25 in Hybrid Search**: Fine-tune search relevance by adjusting the weight for BM25 inside hybrid search from the UI, letting you tailor knowledge search results to your workflow.
+- 🧪 **Bypass File Creation with CTRL + SHIFT + V**: When “Paste Large Text as File” is enabled, use CTRL + SHIFT + V to skip the file creation dialog and instantly upload text as a file—perfect for rapid document prep.
+- 🌐 **Bypass Web Loader in Web Search**: Choose to bypass web content loading and use snippets directly in web search for faster, more reliable results when page loads are slow or blocked.
+- 🚀 **Environment Variable: WEBUI_AUTH_TRUSTED_GROUPS_HEADER**: Now sync and manage user groups directly via a trusted HTTP header, unlocking smoother single sign-on and identity integrations for organizations.
+- 🏢 **Workspace Models Visibility Controls**: You can now hide workspace-level models from both the model selector and shared environments—keep your team focused and reduce clutter from rarely-used endpoints.
+- 🛡️ **Copy Model Link**: You can now copy a direct link to any model—including those hidden from the selector—making sharing and onboarding others more seamless.
+- 🔗 **Load Function Directly from URL**: Simplify custom function management—just paste any GitHub function URL into Open WebUI and import new functions in seconds.
+- ⚙️ **Custom Name/Description for External Tool Servers**: Personalize and clarify external tool servers by assigning custom names and descriptions, making it easier to manage integrations in large-scale workspaces.
+- 🌍 **Custom OpenAPI JSON URL Support for Tool Servers**: You can now specify any custom OpenAPI JSON URL, unlocking more flexible integration with any backend for tool calls.
+- 📊 **Source Field Now Displays in Non-Streaming Responses with Attachments**: When files or knowledge are attached, the "source" field now appears for all responses, even in non-streaming mode—enabling improved citation workflows.
+- 🎛 **Pinned Chats**: Reduced payload size on pinned chat requests—leading to faster load times and less data usage, especially in busy workspaces.
+- 🛠 **Import/Export Default Prompt Suggestions**: Enjoy one-click import/export of prompt suggestions, making it much easier to share, reuse, and manage best practices across teams or deployments.
+- 🍰 **Banners Now Sortable from Admin Settings**: Quickly re-order or prioritize banners, letting you highlight the most critical info for your team.
+- 🛠 **Advanced Chat Parameters—Clearer Ollama Support Labels**: Parameters and advanced settings now explicitly indicate if they are Ollama-specific, reducing confusion and improving setup accuracy.
+- 🤏 **Scroll Bar Thumb Improved for Better Visibility**: Enhanced scrollbar styling makes navigation more accessible and visually intuitive.
+- 🗄️ **Modal Redesign for Archived and User Chat Listings**: A clean, modern modal interface for browsing archived and user-specific chats makes locating conversations faster and more pleasant.
+- 📝 **Add/Edit Memory Modal UX**: Memory modals are now larger and have resizable input fields, supporting easier editing of long or complex memory content.
+- 🏆 **Translation & Localization Enhancements**: Major upgrades to Chinese (Simplified & Traditional), Korean, Russian, German, Danish, and Finnish—not just typo fixes, but improved consistency, tone, and terminology for a more natural native-language experience.
+- ⚡ **General Backend Stability & Security Enhancements**: Various backend refinements ensure a more resilient, reliable, and secure platform for smoother operation and peace of mind.
+
+### Fixed
+
+- 🖼️ **Image Generation with Allowed File Extensions Now Works Reliably**: Ensure seamless image generation even when strict file extension rules are set—no more blocked creative workflows due to technical hiccups.
+- 🗂 **Remove Leading Dot for File Extension Check**: Fixed an issue where file validation failed because of a leading dot, making file uploads and knowledge management more robust.
+- 🏷️ **Correct Local/External Model Classification**: The platform now accurately distinguishes between local and external models—preventing local models from showing up as external (and vice versa)—ensuring seamless setup, clarity, and management of your AI model endpoints.
+- 📄 **External Document Loader Now Functions as Intended**: External document loaders are reliably invoked, ensuring smoother knowledge ingestion from external sources—expanding your RAG and knowledge workflows.
+- 🎯 **Correct Handling of Toggle Filters**: Toggle filters are now robustly managed, preventing accidental auto-activation and ensuring user preferences are always respected.
+- 🗃 **S3 Tagging Character Restrictions Fixed**: Tags for files in S3 now automatically meet Amazon’s allowed character set, avoiding upload errors and ensuring cross-cloud compatibility.
+- 🛡️ **Authentication Now Uses Password Hash When Duplicate Emails Exist**: Ensures account security and prevents access issues if duplicate emails are present in your system.
+
+### Changed
+
+- 🧩 **Admin Settings: OAuth Redirects Now Use WEBUI_URL**: The OAuth redirect URL is now based on the explicitly set WEBUI_URL, ensuring single sign-on and identity provider integrations always send users to the correct frontend.
+
+### Removed
+
+- 💡 **Duplicate/Typo Component Removals**: Obsolete components have been cleaned up, reducing confusion and improving overall code quality for the team.
+- 🚫 **Streaming Upsert in Pinecone Removed**: Removed streaming upsert references for better compatibility and future-proofing with the latest Pinecone SDK updates.
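Several of the additions above are driven by plain environment variables introduced in the `config.py` and `env.py` hunks below. A minimal sketch of enabling them — the variable names and semantics come from the diff, while the values and the `X-User-Groups` header name are illustrative assumptions:

```bash
# Trusted group sync: the header value is a comma-separated list of group names,
# applied to non-admin users at sign-in. "X-User-Groups" is an example name;
# use whatever header your reverse proxy / identity provider actually sends.
export WEBUI_AUTH_TRUSTED_GROUPS_HEADER="X-User-Groups"

# BM25 weight in hybrid search: <= 0 means vector-only, >= 1 means BM25-only,
# anything between blends the two retrievers (default 0.5).
export RAG_HYBRID_BM25_WEIGHT=0.7

# Use web-search result snippets directly instead of loading each result page.
export BYPASS_WEB_SEARCH_WEB_LOADER=true
```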
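The per-user STT language setting surfaces in the backend as an optional `language` form field on the transcription route (see the `routers/audio.py` hunk further down). A hedged example call, assuming the default `/api/v1/audio/transcriptions` mount and an existing API token:

```bash
# "language" is passed through to faster-whisper or the OpenAI-compatible STT
# backend; omit it to fall back to the server-wide WHISPER_LANGUAGE default.
curl -X POST http://localhost:8080/api/v1/audio/transcriptions \
  -H "Authorization: Bearer $OPEN_WEBUI_API_TOKEN" \
  -F "file=@meeting.mp3" \
  -F "language=de"
```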
+ ## [0.6.10] - 2025-05-19 ### Added diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index b1955b056..441c99efb 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1928,6 +1928,11 @@ RAG_RELEVANCE_THRESHOLD = PersistentConfig( "rag.relevance_threshold", float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")), ) +RAG_HYBRID_BM25_WEIGHT = PersistentConfig( + "RAG_HYBRID_BM25_WEIGHT", + "rag.hybrid_bm25_weight", + float(os.environ.get("RAG_HYBRID_BM25_WEIGHT", "0.5")), +) ENABLE_RAG_HYBRID_SEARCH = PersistentConfig( "ENABLE_RAG_HYBRID_SEARCH", @@ -2177,6 +2182,12 @@ BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig( ) +BYPASS_WEB_SEARCH_WEB_LOADER = PersistentConfig( + "BYPASS_WEB_SEARCH_WEB_LOADER", + "rag.web.search.bypass_web_loader", + os.getenv("BYPASS_WEB_SEARCH_WEB_LOADER", "False").lower() == "true", +) + WEB_SEARCH_RESULT_COUNT = PersistentConfig( "WEB_SEARCH_RESULT_COUNT", "rag.web.search.result_count", @@ -2202,6 +2213,7 @@ WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig( int(os.getenv("WEB_SEARCH_CONCURRENT_REQUESTS", "10")), ) + WEB_LOADER_ENGINE = PersistentConfig( "WEB_LOADER_ENGINE", "rag.web.loader.engine", diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 59557349e..fcfccaedf 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -349,6 +349,10 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None ) WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) +WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get( + "WEBUI_AUTH_TRUSTED_GROUPS_HEADER", None +) + BYPASS_MODEL_ACCESS_CONTROL = ( os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true" diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 340b60ba4..aa7dbccf9 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -54,11 +54,8 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) def get_function_module_by_id(request: Request, pipe_id: str): # Check if function is already loaded - if pipe_id not in request.app.state.FUNCTIONS: - function_module, _, _ = load_function_module_by_id(pipe_id) - request.app.state.FUNCTIONS[pipe_id] = function_module - else: - function_module = request.app.state.FUNCTIONS[pipe_id] + function_module, _, _ = load_function_module_by_id(pipe_id) + request.app.state.FUNCTIONS[pipe_id] = function_module if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): valves = Functions.get_function_valves_by_id(pipe_id) diff --git a/backend/open_webui/internal/wrappers.py b/backend/open_webui/internal/wrappers.py index ccc62b9a5..5cf352930 100644 --- a/backend/open_webui/internal/wrappers.py +++ b/backend/open_webui/internal/wrappers.py @@ -43,7 +43,7 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): def register_connection(db_url): - db = connect(db_url, unquote_password=True) + db = connect(db_url, unquote_user=True, unquote_password=True) if isinstance(db, PostgresqlDatabase): # Enable autoconnect for SQLite databases, managed by Peewee db.autoconnect = True @@ -51,7 +51,7 @@ def register_connection(db_url): log.info("Connected to PostgreSQL database") # Get the connection details - connection = parse(db_url, unquote_password=True) + connection = parse(db_url, unquote_user=True, unquote_password=True) # Use our custom database class that supports reconnection db = ReconnectingPostgresqlDatabase(**connection) diff 
--git a/backend/open_webui/main.py b/backend/open_webui/main.py index a5aee4bb8..999993e84 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -40,6 +40,8 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles +from starlette_compress import CompressMiddleware + from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware @@ -196,7 +198,10 @@ from open_webui.config import ( RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, RAG_EMBEDDING_ENGINE, RAG_EMBEDDING_BATCH_SIZE, + RAG_TOP_K, + RAG_TOP_K_RERANKER, RAG_RELEVANCE_THRESHOLD, + RAG_HYBRID_BM25_WEIGHT, RAG_ALLOWED_FILE_EXTENSIONS, RAG_FILE_MAX_COUNT, RAG_FILE_MAX_SIZE, @@ -217,8 +222,6 @@ from open_webui.config import ( DOCUMENT_INTELLIGENCE_ENDPOINT, DOCUMENT_INTELLIGENCE_KEY, MISTRAL_OCR_API_KEY, - RAG_TOP_K, - RAG_TOP_K_RERANKER, RAG_TEXT_SPLITTER, TIKTOKEN_ENCODING_NAME, PDF_EXTRACT_IMAGES, @@ -228,6 +231,7 @@ from open_webui.config import ( ENABLE_WEB_SEARCH, WEB_SEARCH_ENGINE, BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, + BYPASS_WEB_SEARCH_WEB_LOADER, WEB_SEARCH_RESULT_COUNT, WEB_SEARCH_CONCURRENT_REQUESTS, WEB_SEARCH_TRUST_ENV, @@ -646,6 +650,7 @@ app.state.FUNCTIONS = {} app.state.config.TOP_K = RAG_TOP_K app.state.config.TOP_K_RERANKER = RAG_TOP_K_RERANKER app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD +app.state.config.HYBRID_BM25_WEIGHT = RAG_HYBRID_BM25_WEIGHT app.state.config.ALLOWED_FILE_EXTENSIONS = RAG_ALLOWED_FILE_EXTENSIONS app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT @@ -707,6 +712,7 @@ app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = ( BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL ) +app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = BYPASS_WEB_SEARCH_WEB_LOADER app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION @@ -959,6 +965,7 @@ class RedirectMiddleware(BaseHTTPMiddleware): # Add the middleware to the app +app.add_middleware(CompressMiddleware) app.add_middleware(RedirectMiddleware) app.add_middleware(SecurityHeadersMiddleware) diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index f07c36c73..3ad88bc11 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -129,12 +129,16 @@ class AuthsTable: def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") + + user = Users.get_user_by_email(email) + if not user: + return None + try: with get_db() as db: - auth = db.query(Auth).filter_by(email=email, active=True).first() + auth = db.query(Auth).filter_by(id=user.id, active=True).first() if auth: if verify_password(password, auth.password): - user = Users.get_user_by_id(auth.id) return user else: return None @@ -155,8 +159,8 @@ class AuthsTable: except Exception: return False - def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: - log.info(f"authenticate_user_by_trusted_header: {email}") + def authenticate_user_by_email(self, email: str) -> Optional[UserModel]: + log.info(f"authenticate_user_by_email: {email}") try: with get_db() as db: auth = 
db.query(Auth).filter_by(email=email, active=True).first() diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 4b4f37197..0ac53a023 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -377,22 +377,47 @@ class ChatTable: return False def get_archived_chat_list_by_user_id( - self, user_id: str, skip: int = 0, limit: int = 50 + self, + user_id: str, + filter: Optional[dict] = None, + skip: int = 0, + limit: int = 50, ) -> list[ChatModel]: + with get_db() as db: - all_chats = ( - db.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) + query = db.query(Chat).filter_by(user_id=user_id, archived=True) + + if filter: + query_key = filter.get("query") + if query_key: + query = query.filter(Chat.title.ilike(f"%{query_key}%")) + + order_by = filter.get("order_by") + direction = filter.get("direction") + + if order_by and direction and getattr(Chat, order_by): + if direction.lower() == "asc": + query = query.order_by(getattr(Chat, order_by).asc()) + elif direction.lower() == "desc": + query = query.order_by(getattr(Chat, order_by).desc()) + else: + raise ValueError("Invalid direction for ordering") + else: + query = query.order_by(Chat.updated_at.desc()) + + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) + + all_chats = query.all() return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_user_id( self, user_id: str, include_archived: bool = False, + filter: Optional[dict] = None, skip: int = 0, limit: int = 50, ) -> list[ChatModel]: @@ -401,7 +426,23 @@ class ChatTable: if not include_archived: query = query.filter_by(archived=False) - query = query.order_by(Chat.updated_at.desc()) + if filter: + query_key = filter.get("query") + if query_key: + query = query.filter(Chat.title.ilike(f"%{query_key}%")) + + order_by = filter.get("order_by") + direction = filter.get("direction") + + if order_by and direction and getattr(Chat, order_by): + if direction.lower() == "asc": + query = query.order_by(getattr(Chat, order_by).asc()) + elif direction.lower() == "desc": + query = query.order_by(getattr(Chat, order_by).desc()) + else: + raise ValueError("Invalid direction for ordering") + else: + query = query.order_by(Chat.updated_at.desc()) if skip: query = query.offset(skip) @@ -542,7 +583,9 @@ class ChatTable: search_text = search_text.lower().strip() if not search_text: - return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit) + return self.get_chat_list_by_user_id( + user_id, include_archived, filter={}, skip=skip, limit=limit + ) search_text_words = search_text.split(" ") diff --git a/backend/open_webui/models/functions.py b/backend/open_webui/models/functions.py index 8cbfc5de7..e98771fa0 100644 --- a/backend/open_webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -108,6 +108,54 @@ class FunctionsTable: log.exception(f"Error creating a new function: {e}") return None + def sync_functions( + self, user_id: str, functions: list[FunctionModel] + ) -> list[FunctionModel]: + # Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present. 
+ try: + with get_db() as db: + # Get existing functions + existing_functions = db.query(Function).all() + existing_ids = {func.id for func in existing_functions} + + # Prepare a set of new function IDs + new_function_ids = {func.id for func in functions} + + # Update or insert functions + for func in functions: + if func.id in existing_ids: + db.query(Function).filter_by(id=func.id).update( + { + **func.model_dump(), + "user_id": user_id, + "updated_at": int(time.time()), + } + ) + else: + new_func = Function( + **{ + **func.model_dump(), + "user_id": user_id, + "updated_at": int(time.time()), + } + ) + db.add(new_func) + + # Remove functions that are no longer present + for func in existing_functions: + if func.id not in new_function_ids: + db.delete(func) + + db.commit() + + return [ + FunctionModel.model_validate(func) + for func in db.query(Function).all() + ] + except Exception as e: + log.exception(f"Error syncing functions for user {user_id}: {e}") + return [] + def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: with get_db() as db: diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index 763340fbc..df79284cf 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -207,5 +207,43 @@ class GroupTable: except Exception: return False + def sync_user_groups_by_group_names( + self, user_id: str, group_names: list[str] + ) -> bool: + with get_db() as db: + try: + groups = db.query(Group).filter(Group.name.in_(group_names)).all() + group_ids = [group.id for group in groups] + + # Remove user from groups not in the new list + existing_groups = self.get_groups_by_member_id(user_id) + + for group in existing_groups: + if group.id not in group_ids: + group.user_ids.remove(user_id) + db.query(Group).filter_by(id=group.id).update( + { + "user_ids": group.user_ids, + "updated_at": int(time.time()), + } + ) + + # Add user to new groups + for group in groups: + if user_id not in group.user_ids: + group.user_ids.append(user_id) + db.query(Group).filter_by(id=group.id).update( + { + "user_ids": group.user_ids, + "updated_at": int(time.time()), + } + ) + + db.commit() + return True + except Exception as e: + log.exception(e) + return False + Groups = GroupTable() diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py index c5f0b4e5e..22397b3b4 100644 --- a/backend/open_webui/retrieval/loaders/main.py +++ b/backend/open_webui/retrieval/loaders/main.py @@ -226,7 +226,7 @@ class Loader: api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"), mime_type=file_content_type, ) - if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"): + elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"): if self._is_text_file(file_ext, file_content_type): loader = TextLoader(file_path, autodetect_encoding=True) else: diff --git a/backend/open_webui/retrieval/loaders/mistral.py b/backend/open_webui/retrieval/loaders/mistral.py index 8f3a960a2..67641d050 100644 --- a/backend/open_webui/retrieval/loaders/mistral.py +++ b/backend/open_webui/retrieval/loaders/mistral.py @@ -1,8 +1,12 @@ import requests +import aiohttp +import asyncio import logging import os import sys +import time from typing import List, Dict, Any +from contextlib import asynccontextmanager from langchain_core.documents import Document from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL @@ -14,18 +18,29 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) class MistralLoader: """ + Enhanced Mistral 
OCR loader with both sync and async support. Loads documents by processing them through the Mistral OCR API. """ BASE_API_URL = "https://api.mistral.ai/v1" - def __init__(self, api_key: str, file_path: str): + def __init__( + self, + api_key: str, + file_path: str, + timeout: int = 300, # 5 minutes default + max_retries: int = 3, + enable_debug_logging: bool = False, + ): """ - Initializes the loader. + Initializes the loader with enhanced features. Args: api_key: Your Mistral API key. file_path: The local path to the PDF file to process. + timeout: Request timeout in seconds. + max_retries: Maximum number of retry attempts. + enable_debug_logging: Enable detailed debug logs. """ if not api_key: raise ValueError("API key cannot be empty.") @@ -34,7 +49,23 @@ class MistralLoader: self.api_key = api_key self.file_path = file_path - self.headers = {"Authorization": f"Bearer {self.api_key}"} + self.timeout = timeout + self.max_retries = max_retries + self.debug = enable_debug_logging + + # Pre-compute file info for performance + self.file_name = os.path.basename(file_path) + self.file_size = os.path.getsize(file_path) + + self.headers = { + "Authorization": f"Bearer {self.api_key}", + "User-Agent": "OpenWebUI-MistralLoader/2.0", + } + + def _debug_log(self, message: str, *args) -> None: + """Conditional debug logging for performance.""" + if self.debug: + log.debug(message, *args) def _handle_response(self, response: requests.Response) -> Dict[str, Any]: """Checks response status and returns JSON content.""" @@ -54,24 +85,89 @@ class MistralLoader: log.error(f"JSON decode error: {json_err} - Response: {response.text}") raise # Re-raise after logging + async def _handle_response_async( + self, response: aiohttp.ClientResponse + ) -> Dict[str, Any]: + """Async version of response handling with better error info.""" + try: + response.raise_for_status() + + # Check content type + content_type = response.headers.get("content-type", "") + if "application/json" not in content_type: + if response.status == 204: + return {} + text = await response.text() + raise ValueError( + f"Unexpected content type: {content_type}, body: {text[:200]}..." + ) + + return await response.json() + + except aiohttp.ClientResponseError as e: + error_text = await response.text() if response else "No response" + log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}") + raise + except aiohttp.ClientError as e: + log.error(f"Client error: {e}") + raise + except Exception as e: + log.error(f"Unexpected error processing response: {e}") + raise + + def _retry_request_sync(self, request_func, *args, **kwargs): + """Synchronous retry logic with exponential backoff.""" + for attempt in range(self.max_retries): + try: + return request_func(*args, **kwargs) + except (requests.exceptions.RequestException, Exception) as e: + if attempt == self.max_retries - 1: + raise + + wait_time = (2**attempt) + 0.5 + log.warning( + f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..." + ) + time.sleep(wait_time) + + async def _retry_request_async(self, request_func, *args, **kwargs): + """Async retry logic with exponential backoff.""" + for attempt in range(self.max_retries): + try: + return await request_func(*args, **kwargs) + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + if attempt == self.max_retries - 1: + raise + + wait_time = (2**attempt) + 0.5 + log.warning( + f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..." 
+ ) + await asyncio.sleep(wait_time) + def _upload_file(self) -> str: - """Uploads the file to Mistral for OCR processing.""" + """Uploads the file to Mistral for OCR processing (sync version).""" log.info("Uploading file to Mistral API") url = f"{self.BASE_API_URL}/files" file_name = os.path.basename(self.file_path) - try: + def upload_request(): with open(self.file_path, "rb") as f: files = {"file": (file_name, f, "application/pdf")} data = {"purpose": "ocr"} - upload_headers = self.headers.copy() # Avoid modifying self.headers - response = requests.post( - url, headers=upload_headers, files=files, data=data + url, + headers=self.headers, + files=files, + data=data, + timeout=self.timeout, ) - response_data = self._handle_response(response) + return self._handle_response(response) + + try: + response_data = self._retry_request_sync(upload_request) file_id = response_data.get("id") if not file_id: raise ValueError("File ID not found in upload response.") @@ -81,16 +177,66 @@ class MistralLoader: log.error(f"Failed to upload file: {e}") raise + async def _upload_file_async(self, session: aiohttp.ClientSession) -> str: + """Async file upload with streaming for better memory efficiency.""" + url = f"{self.BASE_API_URL}/files" + + async def upload_request(): + # Create multipart writer for streaming upload + writer = aiohttp.MultipartWriter("form-data") + + # Add purpose field + purpose_part = writer.append("ocr") + purpose_part.set_content_disposition("form-data", name="purpose") + + # Add file part with streaming + file_part = writer.append_payload( + aiohttp.streams.FilePayload( + self.file_path, + filename=self.file_name, + content_type="application/pdf", + ) + ) + file_part.set_content_disposition( + "form-data", name="file", filename=self.file_name + ) + + self._debug_log( + f"Uploading file: {self.file_name} ({self.file_size:,} bytes)" + ) + + async with session.post( + url, + data=writer, + headers=self.headers, + timeout=aiohttp.ClientTimeout(total=self.timeout), + ) as response: + return await self._handle_response_async(response) + + response_data = await self._retry_request_async(upload_request) + + file_id = response_data.get("id") + if not file_id: + raise ValueError("File ID not found in upload response.") + + log.info(f"File uploaded successfully. 
File ID: {file_id}") + return file_id + def _get_signed_url(self, file_id: str) -> str: - """Retrieves a temporary signed URL for the uploaded file.""" + """Retrieves a temporary signed URL for the uploaded file (sync version).""" log.info(f"Getting signed URL for file ID: {file_id}") url = f"{self.BASE_API_URL}/files/{file_id}/url" params = {"expiry": 1} signed_url_headers = {**self.headers, "Accept": "application/json"} + def url_request(): + response = requests.get( + url, headers=signed_url_headers, params=params, timeout=self.timeout + ) + return self._handle_response(response) + try: - response = requests.get(url, headers=signed_url_headers, params=params) - response_data = self._handle_response(response) + response_data = self._retry_request_sync(url_request) signed_url = response_data.get("url") if not signed_url: raise ValueError("Signed URL not found in response.") @@ -100,8 +246,36 @@ class MistralLoader: log.error(f"Failed to get signed URL: {e}") raise + async def _get_signed_url_async( + self, session: aiohttp.ClientSession, file_id: str + ) -> str: + """Async signed URL retrieval.""" + url = f"{self.BASE_API_URL}/files/{file_id}/url" + params = {"expiry": 1} + + headers = {**self.headers, "Accept": "application/json"} + + async def url_request(): + self._debug_log(f"Getting signed URL for file ID: {file_id}") + async with session.get( + url, + headers=headers, + params=params, + timeout=aiohttp.ClientTimeout(total=self.timeout), + ) as response: + return await self._handle_response_async(response) + + response_data = await self._retry_request_async(url_request) + + signed_url = response_data.get("url") + if not signed_url: + raise ValueError("Signed URL not found in response.") + + self._debug_log("Signed URL received successfully") + return signed_url + def _process_ocr(self, signed_url: str) -> Dict[str, Any]: - """Sends the signed URL to the OCR endpoint for processing.""" + """Sends the signed URL to the OCR endpoint for processing (sync version).""" log.info("Processing OCR via Mistral API") url = f"{self.BASE_API_URL}/ocr" ocr_headers = { @@ -118,43 +292,198 @@ class MistralLoader: "include_image_base64": False, } + def ocr_request(): + response = requests.post( + url, headers=ocr_headers, json=payload, timeout=self.timeout + ) + return self._handle_response(response) + try: - response = requests.post(url, headers=ocr_headers, json=payload) - ocr_response = self._handle_response(response) + ocr_response = self._retry_request_sync(ocr_request) log.info("OCR processing done.") - log.debug("OCR response: %s", ocr_response) + self._debug_log("OCR response: %s", ocr_response) return ocr_response except Exception as e: log.error(f"Failed during OCR processing: {e}") raise + async def _process_ocr_async( + self, session: aiohttp.ClientSession, signed_url: str + ) -> Dict[str, Any]: + """Async OCR processing with timing metrics.""" + url = f"{self.BASE_API_URL}/ocr" + + headers = { + **self.headers, + "Content-Type": "application/json", + "Accept": "application/json", + } + + payload = { + "model": "mistral-ocr-latest", + "document": { + "type": "document_url", + "document_url": signed_url, + }, + "include_image_base64": False, + } + + async def ocr_request(): + log.info("Starting OCR processing via Mistral API") + start_time = time.time() + + async with session.post( + url, + json=payload, + headers=headers, + timeout=aiohttp.ClientTimeout(total=self.timeout), + ) as response: + ocr_response = await self._handle_response_async(response) + + processing_time = time.time() - 
start_time + log.info(f"OCR processing completed in {processing_time:.2f}s") + + return ocr_response + + return await self._retry_request_async(ocr_request) + def _delete_file(self, file_id: str) -> None: - """Deletes the file from Mistral storage.""" + """Deletes the file from Mistral storage (sync version).""" log.info(f"Deleting uploaded file ID: {file_id}") url = f"{self.BASE_API_URL}/files/{file_id}" - # No specific Accept header needed, default or Authorization is usually sufficient try: - response = requests.delete(url, headers=self.headers) - delete_response = self._handle_response( - response - ) # Check status, ignore response body unless needed - log.info( - f"File deleted successfully: {delete_response}" - ) # Log the response if available + response = requests.delete(url, headers=self.headers, timeout=30) + delete_response = self._handle_response(response) + log.info(f"File deleted successfully: {delete_response}") except Exception as e: # Log error but don't necessarily halt execution if deletion fails log.error(f"Failed to delete file ID {file_id}: {e}") - # Depending on requirements, you might choose to raise the error here + + async def _delete_file_async( + self, session: aiohttp.ClientSession, file_id: str + ) -> None: + """Async file deletion with error tolerance.""" + try: + + async def delete_request(): + self._debug_log(f"Deleting file ID: {file_id}") + async with session.delete( + url=f"{self.BASE_API_URL}/files/{file_id}", + headers=self.headers, + timeout=aiohttp.ClientTimeout( + total=30 + ), # Shorter timeout for cleanup + ) as response: + return await self._handle_response_async(response) + + await self._retry_request_async(delete_request) + self._debug_log(f"File {file_id} deleted successfully") + + except Exception as e: + # Don't fail the entire process if cleanup fails + log.warning(f"Failed to delete file ID {file_id}: {e}") + + @asynccontextmanager + async def _get_session(self): + """Context manager for HTTP session with optimized settings.""" + connector = aiohttp.TCPConnector( + limit=10, # Total connection limit + limit_per_host=5, # Per-host connection limit + ttl_dns_cache=300, # DNS cache TTL + use_dns_cache=True, + keepalive_timeout=30, + enable_cleanup_closed=True, + ) + + async with aiohttp.ClientSession( + connector=connector, + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"}, + ) as session: + yield session + + def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]: + """Process OCR results into Document objects with enhanced metadata.""" + pages_data = ocr_response.get("pages") + if not pages_data: + log.warning("No pages found in OCR response.") + return [ + Document( + page_content="No text content found", metadata={"error": "no_pages"} + ) + ] + + documents = [] + total_pages = len(pages_data) + skipped_pages = 0 + + for page_data in pages_data: + page_content = page_data.get("markdown") + page_index = page_data.get("index") # API uses 0-based index + + if page_content is not None and page_index is not None: + # Clean up content efficiently + cleaned_content = ( + page_content.strip() + if isinstance(page_content, str) + else str(page_content) + ) + + if cleaned_content: # Only add non-empty pages + documents.append( + Document( + page_content=cleaned_content, + metadata={ + "page": page_index, # 0-based index from API + "page_label": page_index + + 1, # 1-based label for convenience + "total_pages": total_pages, + "file_name": self.file_name, + "file_size": 
self.file_size, + "processing_engine": "mistral-ocr", + }, + ) + ) + else: + skipped_pages += 1 + self._debug_log(f"Skipping empty page {page_index}") + else: + skipped_pages += 1 + self._debug_log( + f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}" + ) + + if skipped_pages > 0: + log.info( + f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages" + ) + + if not documents: + # Case where pages existed but none had valid markdown/index + log.warning( + "OCR response contained pages, but none had valid content/index." + ) + return [ + Document( + page_content="No valid text content found in document", + metadata={"error": "no_valid_pages", "total_pages": total_pages}, + ) + ] + + return documents def load(self) -> List[Document]: """ Executes the full OCR workflow: upload, get URL, process OCR, delete file. + Synchronous version for backward compatibility. Returns: A list of Document objects, one for each page processed. """ file_id = None + start_time = time.time() + try: # 1. Upload file file_id = self._upload_file() @@ -166,53 +495,30 @@ class MistralLoader: ocr_response = self._process_ocr(signed_url) # 4. Process results - pages_data = ocr_response.get("pages") - if not pages_data: - log.warning("No pages found in OCR response.") - return [Document(page_content="No text content found", metadata={})] + documents = self._process_results(ocr_response) - documents = [] - total_pages = len(pages_data) - for page_data in pages_data: - page_content = page_data.get("markdown") - page_index = page_data.get("index") # API uses 0-based index - - if page_content is not None and page_index is not None: - documents.append( - Document( - page_content=page_content, - metadata={ - "page": page_index, # 0-based index from API - "page_label": page_index - + 1, # 1-based label for convenience - "total_pages": total_pages, - # Add other relevant metadata from page_data if available/needed - # e.g., page_data.get('width'), page_data.get('height') - }, - ) - ) - else: - log.warning( - f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}" - ) - - if not documents: - # Case where pages existed but none had valid markdown/index - log.warning( - "OCR response contained pages, but none had valid content/index." - ) - return [ - Document( - page_content="No text content found in valid pages", metadata={} - ) - ] + total_time = time.time() - start_time + log.info( + f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents" + ) return documents except Exception as e: - log.error(f"An error occurred during the loading process: {e}") - # Return an empty list or a specific error document on failure - return [Document(page_content=f"Error during processing: {e}", metadata={})] + total_time = time.time() - start_time + log.error( + f"An error occurred during the loading process after {total_time:.2f}s: {e}" + ) + # Return an error document on failure + return [ + Document( + page_content=f"Error during processing: {e}", + metadata={ + "error": "processing_failed", + "file_name": self.file_name, + }, + ) + ] finally: # 5. Delete file (attempt even if prior steps failed after upload) if file_id: @@ -223,3 +529,105 @@ class MistralLoader: log.error( f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}" ) + + async def load_async(self) -> List[Document]: + """ + Asynchronous OCR workflow execution with optimized performance. + + Returns: + A list of Document objects, one for each page processed. 
+ """ + file_id = None + start_time = time.time() + + try: + async with self._get_session() as session: + # 1. Upload file with streaming + file_id = await self._upload_file_async(session) + + # 2. Get signed URL + signed_url = await self._get_signed_url_async(session, file_id) + + # 3. Process OCR + ocr_response = await self._process_ocr_async(session, signed_url) + + # 4. Process results + documents = self._process_results(ocr_response) + + total_time = time.time() - start_time + log.info( + f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents" + ) + + return documents + + except Exception as e: + total_time = time.time() - start_time + log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}") + return [ + Document( + page_content=f"Error during OCR processing: {e}", + metadata={ + "error": "processing_failed", + "file_name": self.file_name, + }, + ) + ] + finally: + # 5. Cleanup - always attempt file deletion + if file_id: + try: + async with self._get_session() as session: + await self._delete_file_async(session, file_id) + except Exception as cleanup_error: + log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}") + + @staticmethod + async def load_multiple_async( + loaders: List["MistralLoader"], + ) -> List[List[Document]]: + """ + Process multiple files concurrently for maximum performance. + + Args: + loaders: List of MistralLoader instances + + Returns: + List of document lists, one for each loader + """ + if not loaders: + return [] + + log.info(f"Starting concurrent processing of {len(loaders)} files") + start_time = time.time() + + # Process all files concurrently + tasks = [loader.load_async() for loader in loaders] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Handle any exceptions in results + processed_results = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + log.error(f"File {i} failed: {result}") + processed_results.append( + [ + Document( + page_content=f"Error processing file: {result}", + metadata={ + "error": "batch_processing_failed", + "file_index": i, + }, + ) + ] + ) + else: + processed_results.append(result) + + total_time = time.time() - start_time + total_docs = sum(len(docs) for docs in processed_results) + log.info( + f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents" + ) + + return processed_results diff --git a/backend/open_webui/retrieval/models/base_reranker.py b/backend/open_webui/retrieval/models/base_reranker.py new file mode 100644 index 000000000..6be7a5649 --- /dev/null +++ b/backend/open_webui/retrieval/models/base_reranker.py @@ -0,0 +1,8 @@ +from abc import ABC, abstractmethod +from typing import Optional, List, Tuple + + +class BaseReranker(ABC): + @abstractmethod + def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]: + pass diff --git a/backend/open_webui/retrieval/models/colbert.py b/backend/open_webui/retrieval/models/colbert.py index 5b7499fd1..7ec888437 100644 --- a/backend/open_webui/retrieval/models/colbert.py +++ b/backend/open_webui/retrieval/models/colbert.py @@ -7,11 +7,13 @@ from colbert.modeling.checkpoint import Checkpoint from open_webui.env import SRC_LOG_LEVELS +from open_webui.retrieval.models.base_reranker import BaseReranker + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -class ColBERT: +class ColBERT(BaseReranker): def __init__(self, name, **kwargs) -> None: log.info("ColBERT: Loading model", name) self.device = "cuda" if 
torch.cuda.is_available() else "cpu" diff --git a/backend/open_webui/retrieval/models/external.py b/backend/open_webui/retrieval/models/external.py index 187d66e38..5ebc3e52e 100644 --- a/backend/open_webui/retrieval/models/external.py +++ b/backend/open_webui/retrieval/models/external.py @@ -3,12 +3,14 @@ import requests from typing import Optional, List, Tuple from open_webui.env import SRC_LOG_LEVELS +from open_webui.retrieval.models.base_reranker import BaseReranker + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -class ExternalReranker: +class ExternalReranker(BaseReranker): def __init__( self, api_key: str, diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index a132d7201..97a89880c 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -116,6 +116,7 @@ def query_doc_with_hybrid_search( reranking_function, k_reranker: int, r: float, + hybrid_bm25_weight: float, ) -> dict: try: log.debug(f"query_doc_with_hybrid_search:doc {collection_name}") @@ -131,9 +132,20 @@ def query_doc_with_hybrid_search( top_k=k, ) - ensemble_retriever = EnsembleRetriever( - retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5] - ) + if hybrid_bm25_weight <= 0: + ensemble_retriever = EnsembleRetriever( + retrievers=[vector_search_retriever], weights=[1.0] + ) + elif hybrid_bm25_weight >= 1: + ensemble_retriever = EnsembleRetriever( + retrievers=[bm25_retriever], weights=[1.0] + ) + else: + ensemble_retriever = EnsembleRetriever( + retrievers=[bm25_retriever, vector_search_retriever], + weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight], + ) + compressor = RerankCompressor( embedding_function=embedding_function, top_n=k_reranker, @@ -313,6 +325,7 @@ def query_collection_with_hybrid_search( reranking_function, k_reranker: int, r: float, + hybrid_bm25_weight: float, ) -> dict: results = [] error = False @@ -346,6 +359,7 @@ def query_collection_with_hybrid_search( reranking_function=reranking_function, k_reranker=k_reranker, r=r, + hybrid_bm25_weight=hybrid_bm25_weight, ) return result, None except Exception as e: @@ -433,6 +447,7 @@ def get_sources_from_files( reranking_function, k_reranker, r, + hybrid_bm25_weight, hybrid_search, full_context=False, ): @@ -550,6 +565,7 @@ def get_sources_from_files( reranking_function=reranking_function, k_reranker=k_reranker, r=r, + hybrid_bm25_weight=hybrid_bm25_weight, ) except Exception as e: log.debug( diff --git a/backend/open_webui/retrieval/vector/dbs/pinecone.py b/backend/open_webui/retrieval/vector/dbs/pinecone.py index c921089b6..9f8abf460 100644 --- a/backend/open_webui/retrieval/vector/dbs/pinecone.py +++ b/backend/open_webui/retrieval/vector/dbs/pinecone.py @@ -1,13 +1,12 @@ from typing import Optional, List, Dict, Any, Union import logging import time # for measuring elapsed time -from pinecone import ServerlessSpec +from pinecone import Pinecone, ServerlessSpec import asyncio # for async upserts import functools # for partial binding in async tasks import concurrent.futures # for parallel batch upserts -from pinecone.grpc import PineconeGRPC # use gRPC client for faster upserts from open_webui.retrieval.vector.main import ( VectorDBBase, @@ -47,10 +46,8 @@ class PineconeClient(VectorDBBase): self.metric = PINECONE_METRIC self.cloud = PINECONE_CLOUD - # Initialize Pinecone gRPC client for improved performance - self.client = PineconeGRPC( - api_key=self.api_key, environment=self.environment, cloud=self.cloud - ) + # Initialize Pinecone 
client for improved performance + self.client = Pinecone(api_key=self.api_key) # Persistent executor for batch operations self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5) @@ -147,8 +144,8 @@ class PineconeClient(VectorDBBase): metadatas = [] for match in matches: - metadata = match.get("metadata", {}) - ids.append(match["id"]) + metadata = getattr(match, "metadata", {}) or {} + ids.append(match.id if hasattr(match, "id") else match["id"]) documents.append(metadata.get("text", "")) metadatas.append(metadata) @@ -174,7 +171,8 @@ class PineconeClient(VectorDBBase): filter={"collection_name": collection_name_with_prefix}, include_metadata=False, ) - return len(response.matches) > 0 + matches = getattr(response, "matches", []) or [] + return len(matches) > 0 except Exception as e: log.exception( f"Error checking collection '{collection_name_with_prefix}': {e}" @@ -321,32 +319,6 @@ class PineconeClient(VectorDBBase): f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'" ) - def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None: - """Perform a streaming upsert over gRPC for performance testing.""" - if not items: - log.warning("No items to upsert via streaming") - return - - collection_name_with_prefix = self._get_collection_name_with_prefix( - collection_name - ) - points = self._create_points(items, collection_name_with_prefix) - - # Open a streaming upsert channel - stream = self.index.streaming_upsert() - try: - for point in points: - # send each point over the stream - stream.send(point) - # close the stream to finalize - stream.close() - log.info( - f"Successfully streamed upsert of {len(points)} vectors into '{collection_name_with_prefix}'" - ) - except Exception as e: - log.error(f"Error during streaming upsert: {e}") - raise - def search( self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int ) -> Optional[SearchResult]: @@ -374,7 +346,8 @@ class PineconeClient(VectorDBBase): filter={"collection_name": collection_name_with_prefix}, ) - if not query_response.matches: + matches = getattr(query_response, "matches", []) or [] + if not matches: # Return empty result if no matches return SearchResult( ids=[[]], @@ -384,13 +357,13 @@ class PineconeClient(VectorDBBase): ) # Convert to GetResult format - get_result = self._result_to_get_result(query_response.matches) + get_result = self._result_to_get_result(matches) # Calculate normalized distances based on metric distances = [ [ - self._normalize_distance(match.score) - for match in query_response.matches + self._normalize_distance(getattr(match, "score", 0.0)) + for match in matches ] ] @@ -432,7 +405,8 @@ class PineconeClient(VectorDBBase): include_metadata=True, ) - return self._result_to_get_result(query_response.matches) + matches = getattr(query_response, "matches", []) or [] + return self._result_to_get_result(matches) except Exception as e: log.error(f"Error querying collection '{collection_name}': {e}") @@ -456,7 +430,8 @@ class PineconeClient(VectorDBBase): filter={"collection_name": collection_name_with_prefix}, ) - return self._result_to_get_result(query_response.matches) + matches = getattr(query_response, "matches", []) or [] + return self._result_to_get_result(matches) except Exception as e: log.error(f"Error getting collection '{collection_name}': {e}") @@ -516,12 +491,12 @@ class PineconeClient(VectorDBBase): raise def close(self): - """Shut down the gRPC channel and thread pool.""" + """Shut down 
resources.""" try: - self.client.close() - log.info("Pinecone gRPC channel closed.") + # The new Pinecone client doesn't need explicit closing + pass except Exception as e: - log.warning(f"Failed to close Pinecone gRPC channel: {e}") + log.warning(f"Failed to clean up Pinecone resources: {e}") self._executor.shutdown(wait=True) def __enter__(self): diff --git a/backend/open_webui/retrieval/web/searchapi.py b/backend/open_webui/retrieval/web/searchapi.py index 38bc0b574..d7704638c 100644 --- a/backend/open_webui/retrieval/web/searchapi.py +++ b/backend/open_webui/retrieval/web/searchapi.py @@ -42,7 +42,9 @@ def search_searchapi( results = get_filtered_results(results, filter_list) return [ SearchResult( - link=result["link"], title=result["title"], snippet=result["snippet"] + link=result["link"], + title=result.get("title"), + snippet=result.get("snippet"), ) for result in results[:count] ] diff --git a/backend/open_webui/retrieval/web/serpapi.py b/backend/open_webui/retrieval/web/serpapi.py index 028b6bcfe..8762210bf 100644 --- a/backend/open_webui/retrieval/web/serpapi.py +++ b/backend/open_webui/retrieval/web/serpapi.py @@ -42,7 +42,9 @@ def search_serpapi( results = get_filtered_results(results, filter_list) return [ SearchResult( - link=result["link"], title=result["title"], snippet=result["snippet"] + link=result["link"], + title=result.get("title"), + snippet=result.get("snippet"), ) for result in results[:count] ] diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index b8ec538d3..5a90a86e0 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -517,7 +517,6 @@ class SafeWebBaseLoader(WebBaseLoader): async with session.get( url, **(self.requests_kwargs | kwargs), - ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as response: if self.raise_for_status: response.raise_for_status() diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index a0f5af4fc..d337ece2e 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -8,6 +8,8 @@ from pathlib import Path from pydub import AudioSegment from pydub.silence import split_on_silence from concurrent.futures import ThreadPoolExecutor +from typing import Optional + import aiohttp import aiofiles @@ -18,6 +20,7 @@ from fastapi import ( Depends, FastAPI, File, + Form, HTTPException, Request, UploadFile, @@ -527,11 +530,13 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) -def transcription_handler(request, file_path): +def transcription_handler(request, file_path, metadata): filename = os.path.basename(file_path) file_dir = os.path.dirname(file_path) id = filename.split(".")[0] + metadata = metadata or {} + if request.app.state.config.STT_ENGINE == "": if request.app.state.faster_whisper_model is None: request.app.state.faster_whisper_model = set_faster_whisper_model( @@ -543,7 +548,7 @@ def transcription_handler(request, file_path): file_path, beam_size=5, vad_filter=request.app.state.config.WHISPER_VAD_FILTER, - language=WHISPER_LANGUAGE, + language=metadata.get("language") or WHISPER_LANGUAGE, ) log.info( "Detected language '%s' with probability %f" @@ -569,7 +574,14 @@ def transcription_handler(request, file_path): "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}" }, files={"file": (filename, open(file_path, "rb"))}, - data={"model": request.app.state.config.STT_MODEL}, + data={ + "model": 
request.app.state.config.STT_MODEL, + **( + {"language": metadata.get("language")} + if metadata.get("language") + else {} + ), + }, ) r.raise_for_status() @@ -777,8 +789,8 @@ def transcription_handler(request, file_path): ) -def transcribe(request: Request, file_path): - log.info(f"transcribe: {file_path}") +def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None): + log.info(f"transcribe: {file_path} {metadata}") if is_audio_conversion_required(file_path): file_path = convert_audio_to_mp3(file_path) @@ -804,7 +816,7 @@ def transcribe(request: Request, file_path): with ThreadPoolExecutor() as executor: # Submit tasks for each chunk_path futures = [ - executor.submit(transcription_handler, request, chunk_path) + executor.submit(transcription_handler, request, chunk_path, metadata) for chunk_path in chunk_paths ] # Gather results as they complete @@ -812,10 +824,9 @@ def transcribe(request: Request, file_path): try: results.append(future.result()) except Exception as transcribe_exc: - log.exception(f"Error transcribing chunk: {transcribe_exc}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error during transcription.", + detail=f"Error transcribing chunk: {transcribe_exc}", ) finally: # Clean up only the temporary chunks, never the original file @@ -897,6 +908,7 @@ def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"): def transcription( request: Request, file: UploadFile = File(...), + language: Optional[str] = Form(None), user=Depends(get_verified_user), ): log.info(f"file.content_type: {file.content_type}") @@ -926,7 +938,12 @@ def transcription( f.write(contents) try: - result = transcribe(request, file_path) + metadata = None + + if language: + metadata = {"language": language} + + result = transcribe(request, file_path, metadata) return { **result, diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 793bdfd30..06e506228 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -19,12 +19,14 @@ from open_webui.models.auths import ( UserResponse, ) from open_webui.models.users import Users +from open_webui.models.groups import Groups from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.env import ( WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, + WEBUI_AUTH_TRUSTED_GROUPS_HEADER, WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SECURE, WEBUI_AUTH_SIGNOUT_REDIRECT_URL, @@ -299,7 +301,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): 500, detail="Internal error occurred during LDAP user creation." 
) - user = Auths.authenticate_user_by_trusted_header(email) + user = Auths.authenticate_user_by_email(email) if user: expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) @@ -363,21 +365,29 @@ async def signin(request: Request, response: Response, form_data: SigninForm): if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER) - trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower() - trusted_name = trusted_email + email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower() + name = email + if WEBUI_AUTH_TRUSTED_NAME_HEADER: - trusted_name = request.headers.get( - WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email - ) - if not Users.get_user_by_email(trusted_email.lower()): + name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email) + + if not Users.get_user_by_email(email.lower()): await signup( request, response, - SignupForm( - email=trusted_email, password=str(uuid.uuid4()), name=trusted_name - ), + SignupForm(email=email, password=str(uuid.uuid4()), name=name), ) - user = Auths.authenticate_user_by_trusted_header(trusted_email) + + user = Auths.authenticate_user_by_email(email) + if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin": + group_names = request.headers.get( + WEBUI_AUTH_TRUSTED_GROUPS_HEADER, "" + ).split(",") + group_names = [name.strip() for name in group_names if name.strip()] + + if group_names: + Groups.sync_user_groups_by_group_names(user.id, group_names) + elif WEBUI_AUTH == False: admin_email = "admin@localhost" admin_password = "admin" diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 6f00dd4d7..29b12ed67 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -76,17 +76,34 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user @router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse]) async def get_user_chat_list_by_user_id( user_id: str, + page: Optional[int] = None, + query: Optional[str] = None, + order_by: Optional[str] = None, + direction: Optional[str] = None, user=Depends(get_admin_user), - skip: int = 0, - limit: int = 50, ): if not ENABLE_ADMIN_CHAT_ACCESS: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) + + if page is None: + page = 1 + + limit = 60 + skip = (page - 1) * limit + + filter = {} + if query: + filter["query"] = query + if order_by: + filter["order_by"] = order_by + if direction: + filter["direction"] = direction + return Chats.get_chat_list_by_user_id( - user_id, include_archived=True, skip=skip, limit=limit + user_id, include_archived=True, filter=filter, skip=skip, limit=limit ) @@ -194,10 +211,10 @@ async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user) ############################ -@router.get("/pinned", response_model=list[ChatResponse]) +@router.get("/pinned", response_model=list[ChatTitleIdResponse]) async def get_user_pinned_chats(user=Depends(get_verified_user)): return [ - ChatResponse(**chat.model_dump()) + ChatTitleIdResponse(**chat.model_dump()) for chat in Chats.get_pinned_chats_by_user_id(user.id) ] @@ -267,9 +284,37 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): @router.get("/archived", response_model=list[ChatTitleIdResponse]) async def get_archived_session_user_chat_list( - user=Depends(get_verified_user), skip: int = 0, limit: int = 50 + page: Optional[int] = 
None, + query: Optional[str] = None, + order_by: Optional[str] = None, + direction: Optional[str] = None, + user=Depends(get_verified_user), ): - return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) + if page is None: + page = 1 + + limit = 60 + skip = (page - 1) * limit + + filter = {} + if query: + filter["query"] = query + if order_by: + filter["order_by"] = order_by + if direction: + filter["direction"] = direction + + chat_list = [ + ChatTitleIdResponse(**chat.model_dump()) + for chat in Chats.get_archived_chat_list_by_user_id( + user.id, + filter=filter, + skip=skip, + limit=limit, + ) + ] + + return chat_list ############################ diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index ad556d327..ba6758671 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -1,6 +1,7 @@ import logging import os import uuid +import json from fnmatch import fnmatch from pathlib import Path from typing import Optional @@ -10,6 +11,7 @@ from fastapi import ( APIRouter, Depends, File, + Form, HTTPException, Request, UploadFile, @@ -84,19 +86,32 @@ def has_access_to_file( def upload_file( request: Request, file: UploadFile = File(...), - user=Depends(get_verified_user), - file_metadata: dict = None, + metadata: Optional[dict | str] = Form(None), process: bool = Query(True), + internal: bool = False, + user=Depends(get_verified_user), ): log.info(f"file.content_type: {file.content_type}") - file_metadata = file_metadata if file_metadata else {} + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except json.JSONDecodeError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"), + ) + file_metadata = metadata if metadata else {} + try: unsanitized_filename = file.filename filename = os.path.basename(unsanitized_filename) file_extension = os.path.splitext(filename)[1] - if request.app.state.config.ALLOWED_FILE_EXTENSIONS: + # Remove the leading dot from the file extension + file_extension = file_extension[1:] if file_extension else "" + + if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS: request.app.state.config.ALLOWED_FILE_EXTENSIONS = [ ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext ] @@ -144,21 +159,16 @@ def upload_file( "video/webm" }: file_path = Storage.get_file(file_path) - result = transcribe(request, file_path) + result = transcribe(request, file_path, file_metadata) process_file( request, ProcessFileForm(file_id=id, content=result.get("text", "")), user=user, ) - elif file.content_type not in [ - "image/png", - "image/jpeg", - "image/gif", - "video/mp4", - "video/ogg", - "video/quicktime", - ]: + elif (not file.content_type.startswith(("image/", "video/"))) or ( + request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external" + ): process_file(request, ProcessFileForm(file_id=id), user=user) else: log.info( @@ -189,7 +199,7 @@ def upload_file( log.exception(e) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), + detail=ERROR_MESSAGES.DEFAULT("Error uploading file"), ) diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index 206610138..2748fa95c 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -1,5 +1,8 @@ import os +import re + import logging +import aiohttp from pathlib import Path from typing import Optional @@ 
-15,6 +18,8 @@ from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.env import SRC_LOG_LEVELS +from pydantic import BaseModel, HttpUrl + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) @@ -42,6 +47,97 @@ async def get_functions(user=Depends(get_admin_user)): return Functions.get_functions() +############################ +# LoadFunctionFromLink +############################ + + +class LoadUrlForm(BaseModel): + url: HttpUrl + + +def github_url_to_raw_url(url: str) -> str: + # Handle 'tree' (folder) URLs (add main.py at the end) + m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url) + if m1: + org, repo, branch, path = m1.groups() + return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py" + + # Handle 'blob' (file) URLs + m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url) + if m2: + org, repo, branch, path = m2.groups() + return ( + f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}" + ) + + # No match; return as-is + return url + + +@router.post("/load/url", response_model=Optional[dict]) +async def load_function_from_url( + request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user) +): + # NOTE: This is NOT a SSRF vulnerability: + # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use, + # and does NOT accept untrusted user input. Access is enforced by authentication. + + url = str(form_data.url) + if not url: + raise HTTPException(status_code=400, detail="Please enter a valid URL") + + url = github_url_to_raw_url(url) + url_parts = url.rstrip("/").split("/") + + file_name = url_parts[-1] + function_name = ( + file_name[:-3] + if ( + file_name.endswith(".py") + and (not file_name.startswith(("main.py", "index.py", "__init__.py"))) + ) + else url_parts[-2] if len(url_parts) > 1 else "function" + ) + + try: + async with aiohttp.ClientSession() as session: + async with session.get( + url, headers={"Content-Type": "application/json"} + ) as resp: + if resp.status != 200: + raise HTTPException( + status_code=resp.status, detail="Failed to fetch the function" + ) + data = await resp.text() + if not data: + raise HTTPException( + status_code=400, detail="No data received from the URL" + ) + return { + "name": function_name, + "content": data, + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error importing function: {e}") + + +############################ +# SyncFunctions +############################ + + +class SyncFunctionsForm(FunctionForm): + functions: list[FunctionModel] = [] + + +@router.post("/sync", response_model=Optional[FunctionModel]) +async def sync_functions( + request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user) +): + return Functions.sync_functions(user.id, form_data.functions) + + ############################ # CreateNewFunction ############################ @@ -262,11 +358,8 @@ async def get_function_valves_spec_by_id( ): function = Functions.get_function_by_id(id) if function: - if id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[id] - else: - function_module, function_type, frontmatter = load_function_module_by_id(id) - request.app.state.FUNCTIONS[id] = function_module + function_module, function_type, frontmatter = load_function_module_by_id(id) + 
request.app.state.FUNCTIONS[id] = function_module if hasattr(function_module, "Valves"): Valves = function_module.Valves @@ -290,11 +383,8 @@ async def update_function_valves_by_id( ): function = Functions.get_function_by_id(id) if function: - if id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[id] - else: - function_module, function_type, frontmatter = load_function_module_by_id(id) - request.app.state.FUNCTIONS[id] = function_module + function_module, function_type, frontmatter = load_function_module_by_id(id) + request.app.state.FUNCTIONS[id] = function_module if hasattr(function_module, "Valves"): Valves = function_module.Valves @@ -353,11 +443,8 @@ async def get_function_user_valves_spec_by_id( ): function = Functions.get_function_by_id(id) if function: - if id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[id] - else: - function_module, function_type, frontmatter = load_function_module_by_id(id) - request.app.state.FUNCTIONS[id] = function_module + function_module, function_type, frontmatter = load_function_module_by_id(id) + request.app.state.FUNCTIONS[id] = function_module if hasattr(function_module, "UserValves"): UserValves = function_module.UserValves @@ -377,11 +464,8 @@ async def update_function_user_valves_by_id( function = Functions.get_function_by_id(id) if function: - if id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[id] - else: - function_module, function_type, frontmatter = load_function_module_by_id(id) - request.app.state.FUNCTIONS[id] = function_module + function_module, function_type, frontmatter = load_function_module_by_id(id) + request.app.state.FUNCTIONS[id] = function_module if hasattr(function_module, "UserValves"): UserValves = function_module.UserValves diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index b8bb110f5..c6d8e4186 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -333,10 +333,11 @@ def get_models(request: Request, user=Depends(get_verified_user)): return [ {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, + {"id": "gpt-image-1", "name": "GPT-IMAGE 1"}, ] elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": return [ - {"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"}, + {"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"}, ] elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": # TODO - get models from comfyui @@ -450,7 +451,7 @@ def load_url_image_data(url, headers=None): return None -def upload_image(request, image_metadata, image_data, content_type, user): +def upload_image(request, image_data, content_type, metadata, user): image_format = mimetypes.guess_extension(content_type) file = UploadFile( file=io.BytesIO(image_data), @@ -459,7 +460,7 @@ def upload_image(request, image_metadata, image_data, content_type, user): "content-type": content_type, }, ) - file_item = upload_file(request, file, user, file_metadata=image_metadata) + file_item = upload_file(request, file, metadata=metadata, internal=True, user=user) url = request.app.url_path_for("get_file_content_by_id", id=file_item.id) return url @@ -526,7 +527,7 @@ async def image_generations( else: image_data, content_type = load_b64_image_data(image["b64_json"]) - url = upload_image(request, data, image_data, content_type, user) + url = upload_image(request, image_data, content_type, data, user) 
images.append({"url": url}) return images @@ -560,7 +561,7 @@ async def image_generations( image_data, content_type = load_b64_image_data( image["bytesBase64Encoded"] ) - url = upload_image(request, data, image_data, content_type, user) + url = upload_image(request, image_data, content_type, data, user) images.append({"url": url}) return images @@ -611,9 +612,9 @@ async def image_generations( image_data, content_type = load_url_image_data(image["url"], headers) url = upload_image( request, - form_data.model_dump(exclude_none=True), image_data, content_type, + form_data.model_dump(exclude_none=True), user, ) images.append({"url": url}) @@ -664,9 +665,9 @@ async def image_generations( image_data, content_type = load_b64_image_data(image) url = upload_image( request, - {**data, "info": res["info"]}, image_data, content_type, + {**data, "info": res["info"]}, user, ) images.append({"url": url}) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 7c313ea97..1410831d7 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -9,6 +9,8 @@ import os import random import re import time +from datetime import datetime + from typing import Optional, Union from urllib.parse import urlparse import aiohttp @@ -300,6 +302,22 @@ async def update_config( } +def merge_ollama_models_lists(model_lists): + merged_models = {} + + for idx, model_list in enumerate(model_lists): + if model_list is not None: + for model in model_list: + id = model["model"] + if id not in merged_models: + model["urls"] = [idx] + merged_models[id] = model + else: + merged_models[id]["urls"].append(idx) + + return list(merged_models.values()) + + @cached(ttl=1) async def get_all_models(request: Request, user: UserModel = None): log.info("get_all_models()") @@ -364,23 +382,8 @@ async def get_all_models(request: Request, user: UserModel = None): if connection_type: model["connection_type"] = connection_type - def merge_models_lists(model_lists): - merged_models = {} - - for idx, model_list in enumerate(model_lists): - if model_list is not None: - for model in model_list: - id = model["model"] - if id not in merged_models: - model["urls"] = [idx] - merged_models[id] = model - else: - merged_models[id]["urls"].append(idx) - - return list(merged_models.values()) - models = { - "models": merge_models_lists( + "models": merge_ollama_models_lists( map( lambda response: response.get("models", []) if response else None, responses, @@ -388,6 +391,22 @@ async def get_all_models(request: Request, user: UserModel = None): ) } + try: + loaded_models = await get_ollama_loaded_models(request, user=user) + expires_map = { + m["name"]: m["expires_at"] + for m in loaded_models["models"] + if "expires_at" in m + } + + for m in models["models"]: + if m["name"] in expires_map: + # Parse ISO8601 datetime with offset, get unix timestamp as int + dt = datetime.fromisoformat(expires_map[m["name"]]) + m["expires_at"] = int(dt.timestamp()) + except Exception as e: + log.debug(f"Failed to get loaded models: {e}") + else: models = {"models": []} @@ -468,6 +487,68 @@ async def get_ollama_tags( return models +@router.get("/api/ps") +async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)): + """ + List models that are currently loaded into Ollama memory, and which node they are loaded on. 
+ """ + if request.app.state.config.ENABLE_OLLAMA_API: + request_tasks = [] + for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS): + if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and ( + url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support + ): + request_tasks.append(send_get_request(f"{url}/api/ps", user=user)) + else: + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + str(idx), + request.app.state.config.OLLAMA_API_CONFIGS.get( + url, {} + ), # Legacy support + ) + + enable = api_config.get("enable", True) + key = api_config.get("key", None) + + if enable: + request_tasks.append( + send_get_request(f"{url}/api/ps", key, user=user) + ) + else: + request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) + + responses = await asyncio.gather(*request_tasks) + + for idx, response in enumerate(responses): + if response: + url = request.app.state.config.OLLAMA_BASE_URLS[idx] + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + str(idx), + request.app.state.config.OLLAMA_API_CONFIGS.get( + url, {} + ), # Legacy support + ) + + prefix_id = api_config.get("prefix_id", None) + + for model in response.get("models", []): + if prefix_id: + model["model"] = f"{prefix_id}.{model['model']}" + + models = { + "models": merge_ollama_models_lists( + map( + lambda response: response.get("models", []) if response else None, + responses, + ) + ) + } + else: + models = {"models": []} + + return models + + @router.get("/api/version") @router.get("/api/version/{url_idx}") async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): @@ -541,36 +622,74 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): return {"version": False} -@router.get("/api/ps") -async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)): - """ - List models that are currently loaded into Ollama memory, and which node they are loaded on. - """ - if request.app.state.config.ENABLE_OLLAMA_API: - request_tasks = [ - send_get_request( - f"{url}/api/ps", - request.app.state.config.OLLAMA_API_CONFIGS.get( - str(idx), - request.app.state.config.OLLAMA_API_CONFIGS.get( - url, {} - ), # Legacy support - ).get("key", None), - user=user, - ) - for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS) - ] - responses = await asyncio.gather(*request_tasks) - - return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses)) - else: - return {} - - class ModelNameForm(BaseModel): name: str +@router.post("/api/unload") +async def unload_model( + request: Request, + form_data: ModelNameForm, + user=Depends(get_admin_user), +): + model_name = form_data.name + if not model_name: + raise HTTPException( + status_code=400, detail="Missing 'name' of model to unload." 
+ ) + + # Refresh/load models if needed, get mapping from name to URLs + await get_all_models(request, user=user) + models = request.app.state.OLLAMA_MODELS + + # Canonicalize model name (if not supplied with version) + if ":" not in model_name: + model_name = f"{model_name}:latest" + + if model_name not in models: + raise HTTPException( + status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name) + ) + url_indices = models[model_name]["urls"] + + # Send unload to ALL url_indices + results = [] + errors = [] + for idx in url_indices: + url = request.app.state.config.OLLAMA_BASE_URLS[idx] + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + ) + key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS) + + prefix_id = api_config.get("prefix_id", None) + if prefix_id and model_name.startswith(f"{prefix_id}."): + model_name = model_name[len(f"{prefix_id}.") :] + + payload = {"model": model_name, "keep_alive": 0, "prompt": ""} + + try: + res = await send_post_request( + url=f"{url}/api/generate", + payload=json.dumps(payload), + stream=False, + key=key, + user=user, + ) + results.append({"url_idx": idx, "success": True, "response": res}) + except Exception as e: + log.exception(f"Failed to unload model on node {idx}: {e}") + errors.append({"url_idx": idx, "success": False, "error": str(e)}) + + if len(errors) > 0: + raise HTTPException( + status_code=500, + detail=f"Failed to unload model on {len(errors)} nodes: {errors}", + ) + + return {"status": True} + + @router.post("/api/pull") @router.post("/api/pull/{url_idx}") async def pull_model( diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 5cb47373f..98f79c7fe 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -349,6 +349,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER, "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD, + "HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT, # Content extraction settings "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, @@ -387,6 +388,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, + "BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER, "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, "YACY_USERNAME": request.app.state.config.YACY_USERNAME, @@ -439,6 +441,7 @@ class WebConfig(BaseModel): WEB_SEARCH_CONCURRENT_REQUESTS: Optional[int] = None WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = [] BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None + BYPASS_WEB_SEARCH_WEB_LOADER: Optional[bool] = None SEARXNG_QUERY_URL: Optional[str] = None YACY_QUERY_URL: Optional[str] = None YACY_USERNAME: Optional[str] = None @@ -492,6 +495,7 @@ class ConfigForm(BaseModel): 
ENABLE_RAG_HYBRID_SEARCH: Optional[bool] = None TOP_K_RERANKER: Optional[int] = None RELEVANCE_THRESHOLD: Optional[float] = None + HYBRID_BM25_WEIGHT: Optional[float] = None # Content extraction settings CONTENT_EXTRACTION_ENGINE: Optional[str] = None @@ -578,6 +582,11 @@ async def update_rag_config( if form_data.RELEVANCE_THRESHOLD is not None else request.app.state.config.RELEVANCE_THRESHOLD ) + request.app.state.config.HYBRID_BM25_WEIGHT = ( + form_data.HYBRID_BM25_WEIGHT + if form_data.HYBRID_BM25_WEIGHT is not None + else request.app.state.config.HYBRID_BM25_WEIGHT + ) # Content extraction settings request.app.state.config.CONTENT_EXTRACTION_ENGINE = ( @@ -751,6 +760,9 @@ async def update_rag_config( request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = ( form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL ) + request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = ( + form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER + ) request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME @@ -837,6 +849,7 @@ async def update_rag_config( "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER, "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD, + "HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT, # Content extraction settings "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, @@ -875,6 +888,7 @@ async def update_rag_config( "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, + "BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER, "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, "YACY_USERNAME": request.app.state.config.YACY_USERNAME, @@ -1678,13 +1692,29 @@ async def process_web_search( ) try: - loader = get_web_loader( - urls, - verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, - requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, - trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV, - ) - docs = await loader.aload() + if request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER: + docs = [ + Document( + page_content=result.snippet, + metadata={ + "source": result.link, + "title": result.title, + "snippet": result.snippet, + "link": result.link, + }, + ) + for result in search_results + if hasattr(result, "snippet") + ] + else: + loader = get_web_loader( + urls, + verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, + requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, + trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV, + ) + docs = await loader.aload() + urls = [ doc.metadata.get("source") for doc in docs if doc.metadata.get("source") ] # only keep the urls returned by the loader @@ -1774,6 +1804,11 @@ def query_doc_handler( if form_data.r else request.app.state.config.RELEVANCE_THRESHOLD ), + hybrid_bm25_weight=( + form_data.hybrid_bm25_weight + if form_data.hybrid_bm25_weight + 
else request.app.state.config.HYBRID_BM25_WEIGHT + ), user=user, ) else: @@ -1825,6 +1860,11 @@ def query_collection_handler( if form_data.r else request.app.state.config.RELEVANCE_THRESHOLD ), + hybrid_bm25_weight=( + form_data.hybrid_bm25_weight + if form_data.hybrid_bm25_weight + else request.app.state.config.HYBRID_BM25_WEIGHT + ), ) else: return query_collection( diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 318f61398..bd1ce8f62 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -51,11 +51,11 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): **{ "id": f"server:{server['idx']}", "user_id": f"server:{server['idx']}", - "name": server["openapi"] + "name": server.get("openapi", {}) .get("info", {}) .get("title", "Tool Server"), "meta": { - "description": server["openapi"] + "description": server.get("openapi", {}) .get("info", {}) .get("description", ""), }, diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py index 5c85f88bc..41a92fafe 100644 --- a/backend/open_webui/storage/provider.py +++ b/backend/open_webui/storage/provider.py @@ -2,6 +2,7 @@ import os import shutil import json import logging +import re from abc import ABC, abstractmethod from typing import BinaryIO, Tuple, Dict @@ -136,6 +137,11 @@ class S3StorageProvider(StorageProvider): self.bucket_name = S3_BUCKET_NAME self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else "" + @staticmethod + def sanitize_tag_value(s: str) -> str: + """Only include S3 allowed characters.""" + return re.sub(r"[^a-zA-Z0-9 äöüÄÖÜß\+\-=\._:/@]", "", s) + def upload_file( self, file: BinaryIO, filename: str, tags: Dict[str, str] ) -> Tuple[bytes, str]: @@ -145,7 +151,15 @@ class S3StorageProvider(StorageProvider): try: self.s3_client.upload_file(file_path, self.bucket_name, s3_key) if S3_ENABLE_TAGGING and tags: - tagging = {"TagSet": [{"Key": k, "Value": v} for k, v in tags.items()]} + sanitized_tags = { + self.sanitize_tag_value(k): self.sanitize_tag_value(v) + for k, v in tags.items() + } + tagging = { + "TagSet": [ + {"Key": k, "Value": v} for k, v in sanitized_tags.items() + ] + } self.s3_client.put_object_tagging( Bucket=self.bucket_name, Key=s3_key, diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index ce86811d4..d846e35b6 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -392,11 +392,8 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A } ) - if action_id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[action_id] - else: - function_module, _, _ = load_function_module_by_id(action_id) - request.app.state.FUNCTIONS[action_id] = function_module + function_module, _, _ = load_function_module_by_id(action_id) + request.app.state.FUNCTIONS[action_id] = function_module if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): valves = Functions.get_function_valves_by_id(action_id) diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index 02e504765..8a4a7ba49 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -13,11 +13,9 @@ def get_function_module(request, function_id): """ Get the function module by its ID. 
""" - if function_id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[function_id] - else: - function_module, _, _ = load_function_module_by_id(function_id) - request.app.state.FUNCTIONS[function_id] = function_module + + function_module, _, _ = load_function_module_by_id(function_id) + request.app.state.FUNCTIONS[function_id] = function_module return function_module @@ -39,14 +37,17 @@ def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None) for function in Functions.get_functions_by_type("filter", active_only=True) ] - for filter_id in active_filter_ids: + def get_active_status(filter_id): function_module = get_function_module(request, filter_id) - if getattr(function_module, "toggle", None) and ( - filter_id not in enabled_filter_ids - ): - active_filter_ids.remove(filter_id) - continue + if getattr(function_module, "toggle", None): + return filter_id in (enabled_filter_ids or []) + + return True + + active_filter_ids = [ + filter_id for filter_id in active_filter_ids if get_active_status(filter_id) + ] filter_ids = [fid for fid in filter_ids if fid in active_filter_ids] filter_ids.sort(key=get_priority) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index c9095f931..ce6ae2aca 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -41,6 +41,7 @@ from open_webui.routers.pipelines import ( process_pipeline_inlet_filter, process_pipeline_outlet_filter, ) +from open_webui.routers.memories import query_memory, QueryMemoryForm from open_webui.utils.webhook import post_webhook @@ -251,7 +252,12 @@ async def chat_completion_tools_handler( "name": (f"TOOL:{tool_name}"), }, "document": [tool_result], - "metadata": [{"source": (f"TOOL:{tool_name}")}], + "metadata": [ + { + "source": (f"TOOL:{tool_name}"), + "parameters": tool_function_params, + } + ], } ) else: @@ -290,6 +296,38 @@ async def chat_completion_tools_handler( return body, {"sources": sources} +async def chat_memory_handler( + request: Request, form_data: dict, extra_params: dict, user +): + results = await query_memory( + request, + QueryMemoryForm( + **{"content": get_last_user_message(form_data["messages"]), "k": 3} + ), + user, + ) + + user_context = "" + if results and hasattr(results, "documents"): + if results.documents and len(results.documents) > 0: + for doc_idx, doc in enumerate(results.documents[0]): + created_at_date = "Unknown Date" + + if results.metadatas[0][doc_idx].get("created_at"): + created_at_timestamp = results.metadatas[0][doc_idx]["created_at"] + created_at_date = time.strftime( + "%Y-%m-%d", time.localtime(created_at_timestamp) + ) + + user_context += f"{doc_idx + 1}. 
[{created_at_date}] {doc}\n"
+
+    form_data["messages"] = add_or_update_system_message(
+        f"User Context:\n{user_context}\n", form_data["messages"], append=True
+    )
+
+    return form_data
+
+
 async def chat_web_search_handler(
     request: Request, form_data: dict, extra_params: dict, user
 ):
@@ -389,6 +427,7 @@ async def chat_web_search_handler(
                     "name": ", ".join(queries),
                     "type": "web_search",
                     "urls": results["filenames"],
+                    "queries": queries,
                 }
             )
         elif results.get("docs"):
@@ -400,6 +439,7 @@
                     "name": ", ".join(queries),
                     "type": "web_search",
                     "urls": results["filenames"],
+                    "queries": queries,
                 }
             )
@@ -603,6 +643,7 @@ async def chat_completion_files_handler(
                 reranking_function=request.app.state.rf,
                 k_reranker=request.app.state.config.TOP_K_RERANKER,
                 r=request.app.state.config.RELEVANCE_THRESHOLD,
+                hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
                 hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                 full_context=request.app.state.config.RAG_FULL_CONTEXT,
             ),
@@ -774,6 +815,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     features = form_data.pop("features", None)
     if features:
+        if "memory" in features and features["memory"]:
+            form_data = await chat_memory_handler(
+                request, form_data, extra_params, user
+            )
+
         if "web_search" in features and features["web_search"]:
             form_data = await chat_web_search_handler(
                 request, form_data, extra_params, user
@@ -876,6 +922,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                     for doc_context, doc_meta in zip(
                         source["document"], source["metadata"]
                     ):
+                        source_name = source.get("source", {}).get("name", None)
                         citation_id = (
                             doc_meta.get("source", None)
                             or source.get("source", {}).get("id", None)
                         )
                         if citation_id not in citation_idx:
                             citation_idx[citation_id] = len(citation_idx) + 1
-                        context_string += f'<source id="{citation_idx[citation_id]}">{doc_context}</source>\n'
+                        context_string += (
+                            f'<source id="{citation_idx[citation_id]}" name="{source_name}">'
+                            f"{doc_context}</source>\n"
+                        )
     context_string = context_string.strip()
     prompt = get_last_user_message(form_data["messages"])
@@ -950,7 +1001,7 @@ async def process_chat_response(
     message = message_map.get(metadata["message_id"]) if message_map else None
     if message:
-        message_list = get_message_list(message_map, message.get("id"))
+        message_list = get_message_list(message_map, metadata["message_id"])
         # Remove details tags and files from the messages.
        # as get_message_list creates a new list, it does not affect
@@ -967,7 +1018,7 @@ async def process_chat_response(
                     if isinstance(content, str):
                         content = re.sub(
-                            r"<details\b[^>]*>.*?<\/details>",
+                            r"<details\b[^>]*>.*?<\/details>|!\[.*?\]\(.*?\)",
                             "",
                             content,
                             flags=re.S | re.I,
                         )
                     messages.append(
                         {
-                            "role": message["role"],
+                            **message,
+                            "role": message.get(
+                                "role", "assistant"
+                            ),  # Safe fallback for missing role
                             "content": content,
                         }
                     )
@@ -1143,6 +1197,7 @@ async def process_chat_response(
                     metadata["chat_id"],
                     metadata["message_id"],
                     {
+                        "role": "assistant",
                         "content": content,
                     },
                 )
@@ -1165,8 +1220,34 @@ async def process_chat_response(
             await background_tasks_handler()
+            if events and isinstance(events, list) and isinstance(response, dict):
+                extra_response = {}
+                for event in events:
+                    if isinstance(event, dict):
+                        extra_response.update(event)
+                    else:
+                        extra_response[event] = True
+
+                response = {
+                    **extra_response,
+                    **response,
+                }
+
             return response
         else:
+            if events and isinstance(events, list) and isinstance(response, dict):
+                extra_response = {}
+                for event in events:
+                    if isinstance(event, dict):
+                        extra_response.update(event)
+                    else:
+                        extra_response[event] = True
+
+                response = {
+                    **extra_response,
+                    **response,
+                }
+
             return response
     # Non standard response
diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py
index 98938dfea..602794f05 100644
--- a/backend/open_webui/utils/misc.py
+++ b/backend/open_webui/utils/misc.py
@@ -34,11 +34,15 @@ def get_message_list(messages, message_id):
     :return: List of ordered messages starting from the root to the given message
     """
+    # Handle case where messages is None
+    if not messages:
+        return []  # Return empty list instead of None to prevent iteration errors
+
     # Find the message by its id
     current_message = messages.get(message_id)
     if not current_message:
-        return None
+        return []  # Return empty list instead of None to prevent iteration errors
     # Reconstruct the chain by following the parentId links
     message_list = []
@@ -47,7 +51,7 @@
         message_list.insert(
             0, current_message
         )  # Insert the message at the beginning of the list
-        parent_id = current_message["parentId"]
+        parent_id = current_message.get("parentId")  # Use .get() for safety
         current_message = messages.get(parent_id) if parent_id else None
     return message_list
@@ -130,7 +134,9 @@ def prepend_to_first_user_message_content(
     return messages
-def add_or_update_system_message(content: str, messages: list[dict]):
+def add_or_update_system_message(
+    content: str, messages: list[dict], append: bool = False
+):
     """
     Adds a new system message at the beginning of the messages list or updates
     the existing system message at the beginning.
@@ -141,7 +147,10 @@ def add_or_update_system_message(content: str, messages: list[dict]): """ if messages and messages[0].get("role") == "system": - messages[0]["content"] = f"{content}\n{messages[0]['content']}" + if append: + messages[0]["content"] = f"{messages[0]['content']}\n{content}" + else: + messages[0]["content"] = f"{content}\n{messages[0]['content']}" else: # Insert at the beginning messages.insert(0, {"role": "system", "content": content}) diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 77ff0c932..684d2074e 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -239,11 +239,8 @@ async def get_all_models(request, user: UserModel = None): ] def get_function_module_by_id(function_id): - if function_id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[function_id] - else: - function_module, _, _ = load_function_module_by_id(function_id) - request.app.state.FUNCTIONS[function_id] = function_module + function_module, _, _ = load_function_module_by_id(function_id) + request.app.state.FUNCTIONS[function_id] = function_module return function_module for model in models: diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index f6004515f..de3355859 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -536,5 +536,10 @@ class OAuthManager: secure=WEBUI_AUTH_COOKIE_SECURE, ) # Redirect back to the frontend with the JWT token - redirect_url = f"{request.base_url}auth#token={jwt_token}" + + redirect_base_url = request.app.state.config.WEBUI_URL or request.base_url + if redirect_base_url.endswith("/"): + redirect_base_url = redirect_base_url[:-1] + redirect_url = f"{redirect_base_url}/auth#token={jwt_token}" + return RedirectResponse(url=redirect_url, headers=response.headers) diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index d43dfd789..599881c4d 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -57,6 +57,7 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: mappings = { "temperature": float, "top_p": float, + "min_p": float, "max_tokens": int, "frequency_penalty": float, "presence_penalty": float, diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 66bdb4b3e..95018eef1 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -22,7 +22,7 @@ def get_task_model_id( # Set the task model task_model_id = default_model_id # Check if the user has a custom task model and use that model - if models[task_model_id].get("owned_by") == "ollama": + if models[task_model_id].get("connection_type") == "local": if task_model and task_model in models: task_model_id = task_model else: diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index f0b37b605..0774522db 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -160,7 +160,7 @@ def get_tools( # TODO: Fix hack for OpenAI API # Some times breaks OpenAI but others don't. Leaving the comment for val in spec.get("parameters", {}).get("properties", {}).values(): - if val["type"] == "str": + if val.get("type") == "str": val["type"] = "string" # Remove internal reserved parameters (e.g. 
__id__, __user__) @@ -490,8 +490,19 @@ async def get_tool_servers_data( server_entries = [] for idx, server in enumerate(servers): if server.get("config", {}).get("enable"): - url_path = server.get("path", "openapi.json") - full_url = f"{server.get('url')}/{url_path}" + # Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL + openapi_path = server.get("path", "openapi.json") + if "://" in openapi_path: + # If it contains "://", it's a full URL + full_url = openapi_path + else: + if not openapi_path.startswith("/"): + # Ensure the path starts with a slash + openapi_path = f"/{openapi_path}" + + full_url = f"{server.get('url')}{openapi_path}" + + info = server.get("info", {}) auth_type = server.get("auth_type", "bearer") token = None @@ -500,26 +511,37 @@ async def get_tool_servers_data( token = server.get("key", "") elif auth_type == "session": token = session_token - server_entries.append((idx, server, full_url, token)) + server_entries.append((idx, server, full_url, info, token)) # Create async tasks to fetch data - tasks = [get_tool_server_data(token, url) for (_, _, url, token) in server_entries] + tasks = [ + get_tool_server_data(token, url) for (_, _, url, _, token) in server_entries + ] # Execute tasks concurrently responses = await asyncio.gather(*tasks, return_exceptions=True) # Build final results with index and server metadata results = [] - for (idx, server, url, _), response in zip(server_entries, responses): + for (idx, server, url, info, _), response in zip(server_entries, responses): if isinstance(response, Exception): log.error(f"Failed to connect to {url} OpenAPI tool server") continue + openapi_data = response.get("openapi", {}) + + if info and isinstance(openapi_data, dict): + if "name" in info: + openapi_data["info"]["title"] = info.get("name", "Tool Server") + + if "description" in info: + openapi_data["info"]["description"] = info.get("description", "") + results.append( { "idx": idx, "url": server.get("url"), - "openapi": response.get("openapi"), + "openapi": openapi_data, "info": response.get("info"), "specs": response.get("specs"), } diff --git a/backend/requirements.txt b/backend/requirements.txt index 07dc09be6..9930cd3b6 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -12,10 +12,12 @@ aiohttp==3.11.11 async-timeout aiocache aiofiles +starlette-compress==1.6.0 + sqlalchemy==2.0.38 alembic==1.14.0 -peewee==3.17.9 +peewee==3.18.1 peewee-migrate==1.12.2 psycopg2-binary==2.9.9 pgvector==0.4.0 diff --git a/package-lock.json b/package-lock.json index 1c3bc8571..c12535a6d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.6.10", + "version": "0.6.11", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.6.10", + "version": "0.6.11", "dependencies": { "@azure/msal-browser": "^4.5.0", "@codemirror/lang-javascript": "^6.2.2", @@ -22,6 +22,10 @@ "@tiptap/extension-code-block-lowlight": "^2.11.9", "@tiptap/extension-highlight": "^2.10.0", "@tiptap/extension-placeholder": "^2.10.0", + "@tiptap/extension-table": "^2.12.0", + "@tiptap/extension-table-cell": "^2.12.0", + "@tiptap/extension-table-header": "^2.12.0", + "@tiptap/extension-table-row": "^2.12.0", "@tiptap/extension-typography": "^2.10.0", "@tiptap/pm": "^2.11.7", "@tiptap/starter-kit": "^2.10.0", @@ -62,6 +66,7 @@ "prosemirror-schema-basic": "^1.2.3", "prosemirror-schema-list": "^1.5.1", "prosemirror-state": "^1.4.3", + "prosemirror-tables": "^1.7.1", 
"prosemirror-view": "^1.34.3", "pyodide": "^0.27.3", "socket.io-client": "^4.2.0", @@ -69,6 +74,7 @@ "svelte-sonner": "^0.3.19", "tippy.js": "^6.3.7", "turndown": "^7.2.0", + "turndown-plugin-gfm": "^1.0.2", "undici": "^7.3.0", "uuid": "^9.0.1", "vite-plugin-static-copy": "^2.2.0", @@ -3173,6 +3179,59 @@ "@tiptap/core": "^2.7.0" } }, + "node_modules/@tiptap/extension-table": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-table/-/extension-table-2.12.0.tgz", + "integrity": "sha512-tT3IbbBal0vPQ1Bc/3Xl+tmqqZQCYWxnycBPl/WZBqhd57DWzfJqRPESwCGUIJgjOtTnipy/ulvj0FxHi1j9JA==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0", + "@tiptap/pm": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-table-cell": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-table-cell/-/extension-table-cell-2.12.0.tgz", + "integrity": "sha512-8i35uCkmkSiQxMiZ+DLgT/wj24P5U/Zo3jr1e0tMAAMG7sRO1MljjLmkpV8WCdBo0xoRqzkz4J7Nkq+DtzZv9Q==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-table-header": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-table-header/-/extension-table-header-2.12.0.tgz", + "integrity": "sha512-gRKEsy13KKLpg9RxyPeUGqh4BRFSJ2Bc2KQP1ldhef6CPRYHCbGycxXCVQ5aAb7Mhpo54L+AAkmAv1iMHUTflw==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-table-row": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-table-row/-/extension-table-row-2.12.0.tgz", + "integrity": "sha512-AEW/Zl9V0IoaYDBLMhF5lVl0xgoIJs3IuKCsIYxGDlxBfTVFC6PfQzvuy296CMjO5ZcZ0xalVipPV9ggsMRD+w==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, "node_modules/@tiptap/extension-text": { "version": "2.10.0", "resolved": "https://registry.npmjs.org/@tiptap/extension-text/-/extension-text-2.10.0.tgz", @@ -9809,16 +9868,16 @@ } }, "node_modules/prosemirror-tables": { - "version": "1.6.4", - "resolved": "https://registry.npmjs.org/prosemirror-tables/-/prosemirror-tables-1.6.4.tgz", - "integrity": "sha512-TkDY3Gw52gRFRfRn2f4wJv5WOgAOXLJA2CQJYIJ5+kdFbfj3acR4JUW6LX2e1hiEBiUwvEhzH5a3cZ5YSztpIA==", + "version": "1.7.1", + "resolved": "https://registry.npmjs.org/prosemirror-tables/-/prosemirror-tables-1.7.1.tgz", + "integrity": "sha512-eRQ97Bf+i9Eby99QbyAiyov43iOKgWa7QCGly+lrDt7efZ1v8NWolhXiB43hSDGIXT1UXgbs4KJN3a06FGpr1Q==", "license": "MIT", "dependencies": { "prosemirror-keymap": "^1.2.2", - "prosemirror-model": "^1.24.1", + "prosemirror-model": "^1.25.0", "prosemirror-state": "^1.4.3", - "prosemirror-transform": "^1.10.2", - "prosemirror-view": "^1.37.2" + "prosemirror-transform": "^1.10.3", + "prosemirror-view": "^1.39.1" } }, "node_modules/prosemirror-trailing-node": { @@ -9837,9 +9896,9 @@ } }, "node_modules/prosemirror-transform": { - "version": "1.10.2", - "resolved": "https://registry.npmjs.org/prosemirror-transform/-/prosemirror-transform-1.10.2.tgz", - "integrity": "sha512-2iUq0wv2iRoJO/zj5mv8uDUriOHWzXRnOTVgCzSXnktS/2iQRa3UUQwVlkBlYZFtygw6Nh1+X4mGqoYBINn5KQ==", + 
"version": "1.10.4", + "resolved": "https://registry.npmjs.org/prosemirror-transform/-/prosemirror-transform-1.10.4.tgz", + "integrity": "sha512-pwDy22nAnGqNR1feOQKHxoFkkUtepoFAd3r2hbEDsnf4wp57kKA36hXsB3njA9FtONBEwSDnDeCiJe+ItD+ykw==", "license": "MIT", "dependencies": { "prosemirror-model": "^1.21.0" @@ -11808,6 +11867,12 @@ "@mixmark-io/domino": "^2.2.0" } }, + "node_modules/turndown-plugin-gfm": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/turndown-plugin-gfm/-/turndown-plugin-gfm-1.0.2.tgz", + "integrity": "sha512-vwz9tfvF7XN/jE0dGoBei3FXWuvll78ohzCZQuOb+ZjWrs3a0XhQVomJEb2Qh4VHTPNRO4GPZh0V7VRbiWwkRg==", + "license": "MIT" + }, "node_modules/tweetnacl": { "version": "0.14.5", "resolved": "https://registry.npmjs.org/tweetnacl/-/tweetnacl-0.14.5.tgz", diff --git a/package.json b/package.json index 744315c3b..68eb3f0a0 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.6.10", + "version": "0.6.11", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", @@ -66,6 +66,10 @@ "@tiptap/extension-code-block-lowlight": "^2.11.9", "@tiptap/extension-highlight": "^2.10.0", "@tiptap/extension-placeholder": "^2.10.0", + "@tiptap/extension-table": "^2.12.0", + "@tiptap/extension-table-cell": "^2.12.0", + "@tiptap/extension-table-header": "^2.12.0", + "@tiptap/extension-table-row": "^2.12.0", "@tiptap/extension-typography": "^2.10.0", "@tiptap/pm": "^2.11.7", "@tiptap/starter-kit": "^2.10.0", @@ -106,6 +110,7 @@ "prosemirror-schema-basic": "^1.2.3", "prosemirror-schema-list": "^1.5.1", "prosemirror-state": "^1.4.3", + "prosemirror-tables": "^1.7.1", "prosemirror-view": "^1.34.3", "pyodide": "^0.27.3", "socket.io-client": "^4.2.0", @@ -113,6 +118,7 @@ "svelte-sonner": "^0.3.19", "tippy.js": "^6.3.7", "turndown": "^7.2.0", + "turndown-plugin-gfm": "^1.0.2", "undici": "^7.3.0", "uuid": "^9.0.1", "vite-plugin-static-copy": "^2.2.0", diff --git a/pyproject.toml b/pyproject.toml index 01e6bd72c..51ea65890 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,9 +21,11 @@ dependencies = [ "aiocache", "aiofiles", + "starlette-compress==1.6.0", + "sqlalchemy==2.0.38", "alembic==1.14.0", - "peewee==3.17.9", + "peewee==3.18.1", "peewee-migrate==1.12.2", "psycopg2-binary==2.9.9", "pgvector==0.4.0", diff --git a/src/app.css b/src/app.css index 925b9c52d..ea0bd5fb0 100644 --- a/src/app.css +++ b/src/app.css @@ -103,7 +103,7 @@ li p { ::-webkit-scrollbar-thumb { --tw-border-opacity: 1; - background-color: rgba(236, 236, 236, 0.8); + background-color: rgba(215, 215, 215, 0.8); border-color: rgba(255, 255, 255, var(--tw-border-opacity)); border-radius: 9999px; border-width: 1px; @@ -111,12 +111,12 @@ li p { /* Dark theme scrollbar styles */ .dark ::-webkit-scrollbar-thumb { - background-color: rgba(42, 42, 42, 0.8); /* Darker color for dark theme */ + background-color: rgba(67, 67, 67, 0.8); /* Darker color for dark theme */ border-color: rgba(0, 0, 0, var(--tw-border-opacity)); } ::-webkit-scrollbar { - height: 0.8rem; + height: 0.6rem; width: 0.4rem; } @@ -412,3 +412,29 @@ input[type='number'] { .hljs-strong { font-weight: 700; } + +/* Table styling for tiptap editors */ +.tiptap table { + @apply w-full text-sm text-left text-gray-500 dark:text-gray-400 max-w-full; +} + +.tiptap thead { + @apply text-xs text-gray-700 uppercase bg-gray-50 dark:bg-gray-850 dark:text-gray-400 border-none; +} + +.tiptap th, +.tiptap td { + @apply px-3 py-1.5 border border-gray-100 dark:border-gray-850; +} + +.tiptap th { + @apply 
cursor-pointer text-left text-xs text-gray-700 dark:text-gray-400 font-semibold uppercase bg-gray-50 dark:bg-gray-850; +} + +.tiptap td { + @apply text-gray-900 dark:text-white w-max; +} + +.tiptap tr { + @apply bg-white dark:bg-gray-900 dark:border-gray-850 text-xs; +} diff --git a/src/lib/apis/audio/index.ts b/src/lib/apis/audio/index.ts index f6354da77..b2fed5739 100644 --- a/src/lib/apis/audio/index.ts +++ b/src/lib/apis/audio/index.ts @@ -64,9 +64,12 @@ export const updateAudioConfig = async (token: string, payload: OpenAIConfigForm return res; }; -export const transcribeAudio = async (token: string, file: File) => { +export const transcribeAudio = async (token: string, file: File, language?: string) => { const data = new FormData(); data.append('file', file); + if (language) { + data.append('language', language); + } let error = null; const res = await fetch(`${AUDIO_API_BASE_URL}/transcriptions`, { diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index 0ff56ea23..9d24b3971 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -111,10 +111,79 @@ export const getChatList = async (token: string = '', page: number | null = null })); }; -export const getChatListByUserId = async (token: string = '', userId: string) => { +export const getChatListByUserId = async ( + token: string = '', + userId: string, + page: number = 1, + filter?: object +) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/chats/list/user/${userId}`, { + const searchParams = new URLSearchParams(); + + searchParams.append('page', `${page}`); + + if (filter) { + Object.entries(filter).forEach(([key, value]) => { + if (value !== undefined && value !== null) { + searchParams.append(key, value.toString()); + } + }); + } + + const res = await fetch( + `${WEBUI_API_BASE_URL}/chats/list/user/${userId}?${searchParams.toString()}`, + { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res.map((chat) => ({ + ...chat, + time_range: getTimeRange(chat.updated_at) + })); +}; + +export const getArchivedChatList = async ( + token: string = '', + page: number = 1, + filter?: object +) => { + let error = null; + + const searchParams = new URLSearchParams(); + searchParams.append('page', `${page}`); + + if (filter) { + Object.entries(filter).forEach(([key, value]) => { + if (value !== undefined && value !== null) { + searchParams.append(key, value.toString()); + } + }); + } + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/archived?${searchParams.toString()}`, { method: 'GET', headers: { Accept: 'application/json', @@ -145,37 +214,6 @@ export const getChatListByUserId = async (token: string = '', userId: string) => })); }; -export const getArchivedChatList = async (token: string = '') => { - let error = null; - - const res = await fetch(`${WEBUI_API_BASE_URL}/chats/archived`, { - method: 'GET', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .then((json) => { - return json; - }) - .catch((err) => { - error = err; - 
console.error(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - export const getAllChats = async (token: string) => { let error = null; diff --git a/src/lib/apis/files/index.ts b/src/lib/apis/files/index.ts index 261fe56db..a58d7cb93 100644 --- a/src/lib/apis/files/index.ts +++ b/src/lib/apis/files/index.ts @@ -1,8 +1,12 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const uploadFile = async (token: string, file: File) => { +export const uploadFile = async (token: string, file: File, metadata?: object | null) => { const data = new FormData(); data.append('file', file); + if (metadata) { + data.append('metadata', JSON.stringify(metadata)); + } + let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/files/`, { diff --git a/src/lib/apis/functions/index.ts b/src/lib/apis/functions/index.ts index f1a9bf5a0..60e88118b 100644 --- a/src/lib/apis/functions/index.ts +++ b/src/lib/apis/functions/index.ts @@ -62,6 +62,40 @@ export const getFunctions = async (token: string = '') => { return res; }; +export const loadFunctionByUrl = async (token: string = '', url: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/load/url`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + url + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const exportFunctions = async (token: string = '') => { let error = null; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 710179c12..268be397b 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -346,11 +346,15 @@ export const getToolServersData = async (i18n, servers: object[]) => { .map(async (server) => { const data = await getToolServerData( (server?.auth_type ?? 'bearer') === 'bearer' ? server?.key : localStorage.token, - server?.url + '/' + (server?.path ?? 'openapi.json') + (server?.path ?? '').includes('://') + ? server?.path + : `${server?.url}${(server?.path ?? '').startsWith('/') ? '' : '/'}${server?.path}` ).catch((err) => { toast.error( i18n.t(`Failed to connect to {{URL}} OpenAPI tool server`, { - URL: server?.url + '/' + (server?.path ?? 'openapi.json') + URL: (server?.path ?? '').includes('://') + ? server?.path + : `${server?.url}${(server?.path ?? '').startsWith('/') ? 
'' : '/'}${server?.path}` }) ); return null; diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index f159555da..489055c1b 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -355,6 +355,31 @@ export const generateChatCompletion = async (token: string = '', body: object) = return [res, controller]; }; +export const unloadModel = async (token: string, tagName: string) => { + let error = null; + + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/unload`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + name: tagName + }) + }).catch((err) => { + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const createModel = async (token: string, payload: object, urlIdx: string | null = null) => { let error = null; diff --git a/src/lib/components/AddServerModal.svelte b/src/lib/components/AddServerModal.svelte index ff0a546fa..1c9ce46e2 100644 --- a/src/lib/components/AddServerModal.svelte +++ b/src/lib/components/AddServerModal.svelte @@ -22,7 +22,6 @@ export let edit = false; export let direct = false; - export let connection = null; let url = ''; @@ -33,6 +32,9 @@ let accessControl = {}; + let name = ''; + let description = ''; + let enable = true; let loading = false; @@ -51,7 +53,7 @@ if (direct) { const res = await getToolServerData( auth_type === 'bearer' ? key : localStorage.token, - `${url}/${path}` + path.includes('://') ? path : `${url}${path.startsWith('/') ? '' : '/'}${path}` ).catch((err) => { toast.error($i18n.t('Connection failed')); }); @@ -69,6 +71,10 @@ config: { enable: enable, access_control: accessControl + }, + info: { + name, + description } }).catch((err) => { toast.error($i18n.t('Connection failed')); @@ -95,6 +101,10 @@ config: { enable: enable, access_control: accessControl + }, + info: { + name: name, + description: description } }; @@ -108,6 +118,9 @@ key = ''; auth_type = 'bearer'; + name = ''; + description = ''; + enable = true; accessControl = null; }; @@ -120,6 +133,9 @@ auth_type = connection?.auth_type ?? 'bearer'; key = connection?.key ?? ''; + name = connection.info?.name ?? ''; + description = connection.info?.description ?? ''; + enable = connection.config?.enable ?? true; accessControl = connection.config?.access_control ?? null; } @@ -221,12 +237,11 @@
- [removed markup lines; element tags lost in extraction]
@@ -236,7 +251,7 @@
{$i18n.t(`WebUI will make requests to "{{url}}"`, { - url: `${url}/${path}` + url: path.includes('://') ? path : `${url}${path.startsWith('/') ? '' : '/'}${path}` })}
@@ -276,6 +291,39 @@ {#if !direct}
+ [new markup: "{$i18n.t('Name')}" and "{$i18n.t('Description')}" labeled text inputs for the tool server connection; element tags lost in extraction]
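The same URL-resolution rule now appears in three places in this diff: `getToolServersData`, the modal's verify handler, and the request hint. As a standalone sketch (the helper name and sample values are ours, not the diff's):

```ts
// Sketch: how the tool-server spec URL is now resolved.
// If `path` is already absolute (contains '://'), it wins outright;
// otherwise it is joined onto `url`, inserting a '/' only when needed.
const resolveSpecUrl = (url: string, path: string): string =>
	path.includes('://') ? path : `${url}${path.startsWith('/') ? '' : '/'}${path}`;

// Hypothetical values, for illustration only:
resolveSpecUrl('http://localhost:8000', 'openapi.json'); // http://localhost:8000/openapi.json
resolveSpecUrl('http://localhost:8000', '/v2/openapi.json'); // http://localhost:8000/v2/openapi.json
resolveSpecUrl('http://localhost:8000', 'https://api.example.com/spec.json'); // absolute path wins
```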
diff --git a/src/lib/components/admin/Functions.svelte b/src/lib/components/admin/Functions.svelte index 9897f7f11..afd78303f 100644 --- a/src/lib/components/admin/Functions.svelte +++ b/src/lib/components/admin/Functions.svelte @@ -32,6 +32,9 @@ import Search from '../icons/Search.svelte'; import Plus from '../icons/Plus.svelte'; import ChevronRight from '../icons/ChevronRight.svelte'; + import XMark from '../icons/XMark.svelte'; + import AddFunctionMenu from './Functions/AddFunctionMenu.svelte'; + import ImportModal from './Functions/ImportModal.svelte'; const i18n = getContext('i18n'); @@ -40,6 +43,8 @@ let functionsImportInputElement: HTMLInputElement; let importFiles; + let showImportModal = false; + let showConfirm = false; let query = ''; @@ -196,6 +201,16 @@ + { + sessionStorage.function = JSON.stringify({ + ...func + }); + goto('/admin/functions/create'); + }} +/> +
@@ -215,15 +230,36 @@ bind:value={query} placeholder={$i18n.t('Search Functions')} /> + + {#if query} +
+ [clear-search button with XMark icon; element tags lost in extraction]
+ {/if}
diff --git a/src/lib/components/admin/Functions/AddFunctionMenu.svelte b/src/lib/components/admin/Functions/AddFunctionMenu.svelte
new file mode 100644
index 000000000..6c0f59e1f
--- /dev/null
+++ b/src/lib/components/admin/Functions/AddFunctionMenu.svelte
@@ -0,0 +1,77 @@
+ [new component: the dropdown menu behind the Functions "+" button; it calls onClose when the dropdown reports closed (e.detail === false) and offers entries for creating a new function and importing one from a URL, which opens the Import modal. Script and markup largely lost in extraction.]
diff --git a/src/lib/components/admin/Functions/ImportModal.svelte b/src/lib/components/admin/Functions/ImportModal.svelte
new file mode 100644
index 000000000..47b8c0e2e
--- /dev/null
+++ b/src/lib/components/admin/Functions/ImportModal.svelte
@@ -0,0 +1,145 @@
+ [new component: the "{$i18n.t('Import')}" modal, a form that submits via submitHandler() and contains a single "{$i18n.t('URL')}" text input plus confirm and close controls. Markup largely lost in extraction.]
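For orientation, this is roughly how the pieces fit together: the modal feeds a URL into the `loadFunctionByUrl` wrapper added earlier in `src/lib/apis/functions/index.ts`, and `Functions.svelte` stashes the result for the editor. The URL below is a made-up example:

```ts
import { goto } from '$app/navigation';
import { loadFunctionByUrl } from '$lib/apis/functions';

// Sketch of the import flow: fetch the function definition from a URL,
// stash it for the editor, and navigate there. Error handling mirrors
// the wrapper, which throws the server's `detail` message on failure.
const importFromUrl = async (url: string) => {
	try {
		const func = await loadFunctionByUrl(localStorage.token, url);
		sessionStorage.function = JSON.stringify({ ...func });
		goto('/admin/functions/create');
	} catch (err) {
		console.error('Function import failed:', err);
	}
};

// Hypothetical URL, for illustration only:
importFromUrl('https://github.com/user/repo/blob/main/my_function.py');
```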
diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index 0660dc7ae..4144004fb 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -771,6 +771,26 @@
  {/if}
+
+ {#if RAGConfig.ENABLE_RAG_HYBRID_SEARCH === true}
+   [row: "{$i18n.t('Weight of BM25 Retrieval')}" label with a numeric input; element tags lost in extraction]
+ {/if}
{/if}
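The admin UI only stores the weight; the blending itself happens server-side and is not shown in this diff. As a rough illustration of what a BM25 weight typically means in a hybrid ensemble (an assumption about the formula, not a statement of the backend's implementation):

```ts
// Illustrative only: a common way to blend lexical (BM25) and vector
// scores with a single weight w in [0, 1]. Scores are assumed to be
// normalized to [0, 1] before blending.
const hybridScore = (bm25: number, vector: number, w: number): number =>
	w * bm25 + (1 - w) * vector;

// w = 0.7 favors keyword matches; w = 0.3 favors semantic similarity.
console.log(hybridScore(0.9, 0.4, 0.7)); // 0.75
console.log(hybridScore(0.9, 0.4, 0.3)); // 0.55
```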
diff --git a/src/lib/components/admin/Settings/General.svelte b/src/lib/components/admin/Settings/General.svelte index 85f21824f..df79249b0 100644 --- a/src/lib/components/admin/Settings/General.svelte +++ b/src/lib/components/admin/Settings/General.svelte @@ -84,7 +84,7 @@ if (res) { saveHandler(); } else { - toast.error(i18n.t('Failed to update settings')); + toast.error($i18n.t('Failed to update settings')); } }; diff --git a/src/lib/components/admin/Settings/Interface.svelte b/src/lib/components/admin/Settings/Interface.svelte index 6b062d772..548db5a98 100644 --- a/src/lib/components/admin/Settings/Interface.svelte +++ b/src/lib/components/admin/Settings/Interface.svelte @@ -1,4 +1,7 @@ + +
+ {#each banners as banner, bannerIdx (banner.id)}
+   [banner editor row; element tags lost in extraction]
+ {/each}
diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index 4966e9f6d..875d542ff 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -20,6 +20,7 @@ import Tooltip from '$lib/components/common/Tooltip.svelte'; import Switch from '$lib/components/common/Switch.svelte'; import Spinner from '$lib/components/common/Spinner.svelte'; + import XMark from '$lib/components/icons/XMark.svelte'; import ModelEditor from '$lib/components/workspace/Models/ModelEditor.svelte'; import { toast } from 'svelte-sonner'; @@ -33,6 +34,7 @@ import EllipsisHorizontal from '$lib/components/icons/EllipsisHorizontal.svelte'; import EyeSlash from '$lib/components/icons/EyeSlash.svelte'; import Eye from '$lib/components/icons/Eye.svelte'; + import { copyToClipboard } from '$lib/utils'; let shiftKey = false; @@ -181,6 +183,17 @@ upsertModelHandler(model); }; + const copyLinkHandler = async (model) => { + const baseUrl = window.location.origin; + const res = await copyToClipboard(`${baseUrl}/?model=${encodeURIComponent(model.id)}`); + + if (res) { + toast.success($i18n.t('Copied link to clipboard')); + } else { + toast.error($i18n.t('Failed to copy link')); + } + }; + const exportModelHandler = async (model) => { let blob = new Blob([JSON.stringify([model])], { type: 'application/json' @@ -271,6 +284,18 @@ bind:value={searchValue} placeholder={$i18n.t('Search Models')} /> + {#if searchValue} +
+ [clear-search button with XMark icon; element tags lost in extraction]
+ {/if}
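`copyLinkHandler` above builds the deep link with `encodeURIComponent`, which is load-bearing because model ids (for example Ollama tags) can contain `:` or `/`. A minimal sketch of the same idea with a bare clipboard call; the real `copyToClipboard` helper in `$lib/utils` may behave differently:

```ts
// Sketch: build and copy a shareable model link.
const buildModelLink = (modelId: string): string =>
	`${window.location.origin}/?model=${encodeURIComponent(modelId)}`;

const copyModelLink = async (modelId: string): Promise<boolean> => {
	try {
		await navigator.clipboard.writeText(buildModelLink(modelId));
		return true;
	} catch {
		return false; // clipboard access can fail, e.g. on insecure origins
	}
};

copyModelLink('llama3.1:8b'); // hypothetical model id, for illustration
```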
@@ -381,6 +406,9 @@
  hideHandler={() => {
    hideModelHandler(model);
  }}
+ copyLinkHandler={() => {
+   copyLinkHandler(model);
+ }}
  onClose={() => {}}
>
[diff header lost in extraction; the removed lines below belong to the user-chats modal]
- [removed markup: the old inline chat table with a sortable "{$i18n.t('Title')}" column (▲/▼ indicators toggled via setSortKey('title')), per-chat rows showing {chat.title} with a delete action, and a "{user.name} {$i18n.t('has no conversations.')}" empty state; replaced by a shared paginated list]
- + loadHandler={loadMoreChats} +> diff --git a/src/lib/components/channel/MessageInput.svelte b/src/lib/components/channel/MessageInput.svelte index 2ad4faa93..901e8d58f 100644 --- a/src/lib/components/channel/MessageInput.svelte +++ b/src/lib/components/channel/MessageInput.svelte @@ -17,7 +17,6 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; import FileItem from '../common/FileItem.svelte'; import Image from '../common/Image.svelte'; - import { transcribeAudio } from '$lib/apis/audio'; import FilesOverlay from '../chat/MessageInput/FilesOverlay.svelte'; export let placeholder = $i18n.t('Send a Message'); @@ -160,7 +159,19 @@ try { // During the file upload, file content is automatically extracted. - const uploadedFile = await uploadFile(localStorage.token, file); + + // If the file is an audio file, provide the language for STT. + let metadata = null; + if ( + (file.type.startsWith('audio/') || file.type.startsWith('video/')) && + $settings?.audio?.stt?.language + ) { + metadata = { + language: $settings?.audio?.stt?.language + }; + } + + const uploadedFile = await uploadFile(localStorage.token, file, metadata); if (uploadedFile) { console.info('File upload completed:', { diff --git a/src/lib/components/channel/Navbar.svelte b/src/lib/components/channel/Navbar.svelte index 7c2bc1839..31f94fb48 100644 --- a/src/lib/components/channel/Navbar.svelte +++ b/src/lib/components/channel/Navbar.svelte @@ -57,8 +57,9 @@
{#if $user !== undefined} { if (e.detail === 'archived-chat') { showArchivedChats.set(true); diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 77f2fd6a1..a37ce5be7 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -49,7 +49,8 @@ sleep, removeDetails, getPromptVariables, - processDetails + processDetails, + removeAllDetails } from '$lib/utils'; import { generateChatCompletion } from '$lib/apis/ollama'; @@ -88,6 +89,7 @@ import Placeholder from './Placeholder.svelte'; import NotificationToast from '../NotificationToast.svelte'; import Spinner from '../common/Spinner.svelte'; + import { fade } from 'svelte/transition'; export let chatIdProp = ''; @@ -193,15 +195,27 @@ console.log('saveSessionSelectedModels', selectedModels, sessionStorage.selectedModels); }; - $: if (selectedModels) { - setToolIds(); - setFilterIds(); + let oldSelectedModelIds = ['']; + $: if (JSON.stringify(selectedModelIds) !== JSON.stringify(oldSelectedModelIds)) { + onSelectedModelIdsChange(); } - $: if (atSelectedModel || selectedModels) { + const onSelectedModelIdsChange = () => { + if (oldSelectedModelIds.filter((id) => id).length > 0) { + resetInput(); + } + oldSelectedModelIds = selectedModelIds; + }; + + const resetInput = () => { + console.debug('resetInput'); setToolIds(); - setFilterIds(); - } + + selectedFilterIds = []; + webSearchEnabled = false; + imageGenerationEnabled = false; + codeInterpreterEnabled = false; + }; const setToolIds = async () => { if (!$tools) { @@ -213,20 +227,14 @@ } const model = atSelectedModel ?? $models.find((m) => m.id === selectedModels[0]); - if (model) { + if (model && model?.info?.meta?.toolIds) { selectedToolIds = [ ...new Set( - [...selectedToolIds, ...(model?.info?.meta?.toolIds ?? [])].filter((id) => - $tools.find((t) => t.id === id) - ) + [...(model?.info?.meta?.toolIds ?? [])].filter((id) => $tools.find((t) => t.id === id)) ) ]; - } - }; - - const setFilterIds = async () => { - if (selectedModels.length !== 1 && !atSelectedModel) { - selectedFilterIds = []; + } else { + selectedToolIds = []; } }; @@ -583,9 +591,20 @@ throw new Error('Created file is empty'); } + // If the file is an audio file, provide the language for STT. + let metadata = null; + if ( + (file.type.startsWith('audio/') || file.type.startsWith('video/')) && + $settings?.audio?.stt?.language + ) { + metadata = { + language: $settings?.audio?.stt?.language + }; + } + // Upload file to server console.log('Uploading file to server...'); - const uploadedFile = await uploadFile(localStorage.token, file); + const uploadedFile = await uploadFile(localStorage.token, file, metadata); if (!uploadedFile) { throw new Error('Server returned null response for file upload'); @@ -844,6 +863,8 @@ (chatContent?.models ?? undefined) !== undefined ? chatContent.models : [chatContent.models ?? '']; + oldSelectedModelIds = selectedModels; + history = (chatContent?.history ?? undefined) !== undefined ? chatContent.history @@ -1171,7 +1192,7 @@ // Emit chat event for TTS const messageContentParts = getMessageContentParts( - message.content, + removeAllDetails(message.content), $config?.audio?.tts?.split_on ?? 'punctuation' ); messageContentParts.pop(); @@ -1205,7 +1226,7 @@ // Emit chat event for TTS const messageContentParts = getMessageContentParts( - message.content, + removeAllDetails(message.content), $config?.audio?.tts?.split_on ?? 
'punctuation' ); messageContentParts.pop(); @@ -1252,9 +1273,10 @@ // Emit chat event for TTS let lastMessageContentPart = - getMessageContentParts(message.content, $config?.audio?.tts?.split_on ?? 'punctuation')?.at( - -1 - ) ?? ''; + getMessageContentParts( + removeAllDetails(message.content), + $config?.audio?.tts?.split_on ?? 'punctuation' + )?.at(-1) ?? ''; if (lastMessageContentPart) { eventTarget.dispatchEvent( new CustomEvent('chat', { @@ -1430,7 +1452,6 @@ model: model.id, modelName: model.name ?? model.id, modelIdx: modelIdx ? modelIdx : _modelIdx, - userContext: null, timestamp: Math.floor(Date.now() / 1000) // Unix epoch }; @@ -1485,32 +1506,6 @@ let responseMessageId = responseMessageIds[`${modelId}-${modelIdx ? modelIdx : _modelIdx}`]; - let responseMessage = _history.messages[responseMessageId]; - - let userContext = null; - if ($settings?.memory ?? false) { - if (userContext === null) { - const res = await queryMemory(localStorage.token, prompt).catch((error) => { - toast.error(`${error}`); - return null; - }); - if (res) { - if (res.documents[0].length > 0) { - userContext = res.documents[0].reduce((acc, doc, index) => { - const createdAtTimestamp = res.metadatas[0][index].created_at; - const createdAtDate = new Date(createdAtTimestamp * 1000) - .toISOString() - .split('T')[0]; - return `${acc}${index + 1}. [${createdAtDate}]. ${doc}\n`; - }, ''); - } - - console.log(userContext); - } - } - } - responseMessage.userContext = userContext; - const chatEventEmitter = await getChatEventEmitter(model.id, _chatId); scrollToBottom(); @@ -1572,7 +1567,7 @@ true; let messages = [ - params?.system || $settings.system || (responseMessage?.userContext ?? null) + params?.system || $settings.system ? { role: 'system', content: `${promptTemplate( @@ -1584,11 +1579,7 @@ return undefined; }) : undefined - )}${ - (responseMessage?.userContext ?? null) - ? `\n\nUser Context:\n${responseMessage?.userContext ?? ''}` - : '' - }` + )}` } : undefined, ...createMessagesList(_history, responseMessageId).map((message) => ({ @@ -1665,7 +1656,8 @@ $config?.features?.enable_web_search && ($user?.role === 'admin' || $user?.permissions?.features?.web_search) ? webSearchEnabled || ($settings?.webSearch ?? false) === 'always' - : false + : false, + memory: $settings?.memory ?? false }, variables: { ...getPromptVariables( @@ -2011,196 +2003,198 @@ id="chat-container" > {#if !loading} - {#if $settings?.backgroundImageUrl ?? null} -
- [removed markup: the background-image layer and chat wrappers at their old indentation]
+ [new markup: the same background-image layer plus the Placeholder/Messages tree, rendered when $settings?.landingPageMode === 'chat' or createMessagesList(history, history.currentId).length > 0, re-emitted one indentation level deeper inside the {#if !loading} branch, with {submitPrompt}, {stopResponse}, {showMessage}, and {eventTarget} passed through unchanged; element tags lost in extraction]
{:else if loading}
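The audio/video `metadata` block above now appears in every upload path (channel input, chat input, and the chat's file-creation flow). Factored into one helper, the pattern looks roughly like this; the function name is illustrative, not part of the diff:

```ts
import { uploadFile } from '$lib/apis/files';

// Sketch: attach the user's preferred STT language as upload metadata so
// the backend can pass it to the transcription engine. Only audio and
// video files get the hint; everything else uploads unchanged.
const uploadWithSttLanguage = async (token: string, file: File, sttLanguage?: string) => {
	let metadata: { language: string } | null = null;

	if ((file.type.startsWith('audio/') || file.type.startsWith('video/')) && sttLanguage) {
		metadata = { language: sttLanguage };
	}

	return await uploadFile(token, file, metadata);
};
```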
diff --git a/src/lib/components/chat/ContentRenderer/FloatingButtons.svelte b/src/lib/components/chat/ContentRenderer/FloatingButtons.svelte index 9286aaed0..caded99ab 100644 --- a/src/lib/components/chat/ContentRenderer/FloatingButtons.svelte +++ b/src/lib/components/chat/ContentRenderer/FloatingButtons.svelte @@ -10,7 +10,7 @@ import { chatCompletion } from '$lib/apis/openai'; import ChatBubble from '$lib/components/icons/ChatBubble.svelte'; - import LightBlub from '$lib/components/icons/LightBlub.svelte'; + import LightBulb from '$lib/components/icons/LightBulb.svelte'; import Markdown from '../Messages/Markdown.svelte'; import Skeleton from '../Messages/Skeleton.svelte'; @@ -44,7 +44,13 @@ toast.error('Model not selected'); return; } - prompt = `${floatingInputValue}\n\`\`\`\n${selectedText}\n\`\`\``; + prompt = [ + // Blockquote each line of the selected text + ...selectedText.split('\n').map((line) => `> ${line}`), + '', + // Then your question + floatingInputValue + ].join('\n'); floatingInputValue = ''; responseContent = ''; @@ -121,8 +127,11 @@ toast.error('Model not selected'); return; } - const explainText = $i18n.t('Explain this section to me in more detail'); - prompt = `${explainText}\n\n\`\`\`\n${selectedText}\n\`\`\``; + const quotedText = selectedText + .split('\n') + .map((line) => `> ${line}`) + .join('\n'); + prompt = `${quotedText}\n\nExplain`; responseContent = ''; const [res, controller] = await chatCompletion(localStorage.token, { @@ -256,7 +265,7 @@ explainHandler(); }} > - +
{$i18n.t('Explain')}
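Both floating-button flows now blockquote the selection line by line instead of wrapping it in a code fence, so the model sees quoted context followed by the question. The transformation in isolation:

```ts
// Sketch: markdown-blockquote every line of the selected text, then
// append the user's question (or a bare "Explain") underneath.
const quoteSelection = (selectedText: string, question: string): string =>
	[...selectedText.split('\n').map((line) => `> ${line}`), '', question].join('\n');

quoteSelection('first line\nsecond line', 'Why is this true?');
// > first line
// > second line
//
// Why is this true?
```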
diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index e1a5b1ea6..987355e3b 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -27,7 +27,6 @@ createMessagesList, extractCurlyBraceWords } from '$lib/utils'; - import { transcribeAudio } from '$lib/apis/audio'; import { uploadFile } from '$lib/apis/files'; import { generateAutoCompletion } from '$lib/apis'; import { deleteFileById } from '$lib/apis/files'; @@ -110,7 +109,9 @@ let commandsElement; let inputFiles; + let dragged = false; + let shiftKey = false; let user = null; export let placeholder = ''; @@ -151,6 +152,30 @@ .map((id) => ($models.find((model) => model.id === id) || {})?.filters ?? []) .reduce((acc, filters) => acc.filter((f1) => filters.some((f2) => f2.id === f1.id))); + let showToolsButton = false; + $: showToolsButton = toolServers.length + selectedToolIds.length > 0; + + let showWebSearchButton = false; + $: showWebSearchButton = + (atSelectedModel?.id ? [atSelectedModel.id] : selectedModels).length === + webSearchCapableModels.length && + $config?.features?.enable_web_search && + ($_user.role === 'admin' || $_user?.permissions?.features?.web_search); + + let showImageGenerationButton = false; + $: showImageGenerationButton = + (atSelectedModel?.id ? [atSelectedModel.id] : selectedModels).length === + imageGenerationCapableModels.length && + $config?.features?.enable_image_generation && + ($_user.role === 'admin' || $_user?.permissions?.features?.image_generation); + + let showCodeInterpreterButton = false; + $: showCodeInterpreterButton = + (atSelectedModel?.id ? [atSelectedModel.id] : selectedModels).length === + codeInterpreterCapableModels.length && + $config?.features?.enable_code_interpreter && + ($_user.role === 'admin' || $_user?.permissions?.features?.code_interpreter); + const scrollToBottom = () => { const element = document.getElementById('messages-container'); element.scrollTo({ @@ -225,8 +250,19 @@ files = [...files, fileItem]; try { + // If the file is an audio file, provide the language for STT. + let metadata = null; + if ( + (file.type.startsWith('audio/') || file.type.startsWith('video/')) && + $settings?.audio?.stt?.language + ) { + metadata = { + language: $settings?.audio?.stt?.language + }; + } + // During the file upload, file content is automatically extracted. 
- const uploadedFile = await uploadFile(localStorage.token, file); + const uploadedFile = await uploadFile(localStorage.token, file, metadata); if (uploadedFile) { console.log('File upload completed:', { @@ -318,13 +354,6 @@ }); }; - const handleKeyDown = (event: KeyboardEvent) => { - if (event.key === 'Escape') { - console.log('Escape'); - dragged = false; - } - }; - const onDragOver = (e) => { e.preventDefault(); @@ -355,6 +384,29 @@ dragged = false; }; + const onKeyDown = (e) => { + if (e.key === 'Shift') { + shiftKey = true; + } + + if (e.key === 'Escape') { + console.log('Escape'); + dragged = false; + } + }; + + const onKeyUp = (e) => { + if (e.key === 'Shift') { + shiftKey = false; + } + }; + + const onFocus = () => {}; + + const onBlur = () => { + shiftKey = false; + }; + onMount(async () => { loaded = true; @@ -363,7 +415,11 @@ chatInput?.focus(); }, 0); - window.addEventListener('keydown', handleKeyDown); + window.addEventListener('keydown', onKeyDown); + window.addEventListener('keyup', onKeyUp); + + window.addEventListener('focus', onFocus); + window.addEventListener('blur', onBlur); await tick(); @@ -376,7 +432,11 @@ onDestroy(() => { console.log('destroy'); - window.removeEventListener('keydown', handleKeyDown); + window.removeEventListener('keydown', onKeyDown); + window.removeEventListener('keyup', onKeyUp); + + window.removeEventListener('focus', onFocus); + window.removeEventListener('blur', onBlur); const dropzoneElement = document.getElementById('chat-container'); @@ -641,7 +701,7 @@
{#if $settings?.richTextInput ?? true}
0 ))} placeholder={placeholder ? placeholder : $i18n.t('Send a Message')} - largeTextAsFile={$settings?.largeTextAsFile ?? false} + largeTextAsFile={($settings?.largeTextAsFile ?? false) && !shiftKey} autocomplete={$config?.features?.enable_autocomplete_generation && ($settings?.promptAutocomplete ?? false)} generateAutoCompletion={async (text) => { @@ -839,7 +899,7 @@ reader.readAsDataURL(blob); } else if (item.type === 'text/plain') { - if ($settings?.largeTextAsFile ?? false) { + if (($settings?.largeTextAsFile ?? false) && !shiftKey) { const text = clipboardData.getData('text/plain'); if (text.length > PASTED_TEXT_CHARACTER_LIMIT) { @@ -1070,7 +1130,7 @@ reader.readAsDataURL(blob); } else if (item.type === 'text/plain') { - if ($settings?.largeTextAsFile ?? false) { + if (($settings?.largeTextAsFile ?? false) && !shiftKey) { const text = clipboardData.getData('text/plain'); if (text.length > PASTED_TEXT_CHARACTER_LIMIT) { @@ -1091,8 +1151,8 @@ {/if}
- [removed markup: inline permission/capability conditions wrapped directly around the tools, web-search, image-generation, and code-interpreter buttons]
+ [new markup: the buttons now render behind the precomputed showToolsButton / showWebSearchButton / showImageGenerationButton / showCodeInterpreterButton flags, the tools button carries a {toolServers.length + selectedToolIds.length} badge, and the toggleFilters {#each} loop stays inside the {#if $_user} guard; element tags lost in extraction]
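The paste bypass hinges on tracking the physical Shift key globally and clearing the flag on window blur, so a keyup missed while the window was unfocused cannot leave it stuck. Distilled:

```ts
// Sketch: global Shift tracking used to bypass "Paste Large Text as File".
// While shiftKey is true, pasted text is inserted inline instead of being
// converted to a file, which is why CTRL + SHIFT + V skips file creation.
let shiftKey = false;

const onKeyDown = (e: KeyboardEvent) => {
	if (e.key === 'Shift') shiftKey = true;
};
const onKeyUp = (e: KeyboardEvent) => {
	if (e.key === 'Shift') shiftKey = false;
};
// If the window loses focus mid-shortcut, the keyup never fires; reset.
const onBlur = () => {
	shiftKey = false;
};

window.addEventListener('keydown', onKeyDown);
window.addEventListener('keyup', onKeyUp);
window.addEventListener('blur', onBlur);
```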
{#if (!history?.currentId || history.messages[history.currentId]?.done == true) && ($_user?.role === 'admin' || ($_user?.permissions?.chat?.stt ?? true))}
- [old voice-input button markup]
+ [new voice-input button markup; element tags lost in extraction]

[diff header lost in extraction; the added lines below belong to the citation details view]
+ {#if document.metadata?.parameters}
+   {$i18n.t('Parameters')}
+   {JSON.stringify(document.metadata.parameters, null, 2)}
+ {/if}
  {#if showRelevance}
{$i18n.t('Relevance')} diff --git a/src/lib/components/chat/Messages/Markdown/AlertRenderer.svelte b/src/lib/components/chat/Messages/Markdown/AlertRenderer.svelte index caf1410b1..ae00acb60 100644 --- a/src/lib/components/chat/Messages/Markdown/AlertRenderer.svelte +++ b/src/lib/components/chat/Messages/Markdown/AlertRenderer.svelte @@ -24,7 +24,7 @@ TIP: { border: 'border-emerald-500', text: 'text-emerald-500', - icon: LightBlub + icon: LightBulb }, IMPORTANT: { border: 'border-purple-500', @@ -65,7 +65,7 @@ -
{ + models.set( + await getModels( + localStorage.token, + $config?.features?.enable_direct_connections && ($settings?.directConnections ?? null) + ) + ); + }} + type="button" > {#if selectedModel} {selectedModel.label} @@ -335,7 +364,7 @@ {placeholder} {/if} -
+ {item.label}
[file header lost in extraction; these hunks belong to the model selector dropdown]
- [removed markup: the old badge shown only when item.model.owned_by === 'ollama' with a non-empty parameter_size, and the old owned_by === 'openai' external tag]
+ [new markup: for Ollama models the parameter_size badge is kept, and when item.model.ollama?.expires_at && new Date(item.model.ollama?.expires_at * 1000) > new Date() a loaded-model status indicator with a tooltip is shown; external models are now detected via item.model.connection_type === 'external'; admins additionally get an unload control on loaded Ollama models, rendered beside the selected-value checkmark. Element tags lost in extraction.]
{:else}
@@ -746,7 +811,7 @@
{#if showTemporaryChatControl}
- [old markup]
+ [new markup; element tags lost in extraction]
{#if models[selectedModelIdx]?.name}
  {models[selectedModelIdx]?.name}
{:else}
@@ -221,7 +221,7 @@
- [old markup]
+ [new markup; element tags lost in extraction]
[diff header lost in extraction; the parameter-label renames below belong to the advanced-params settings component]
- {$i18n.t('Logit Bias')} + {'logit_bias'}
- [removed rows; element tags lost in extraction]
+ [remaining advanced-params hunks: each parameter row's display label is renamed to its raw key ('mirostat', 'mirostat_eta', 'mirostat_tau', 'top_k', 'repeat_last_n', 'max_tokens', 'repeat_penalty', 'num_keep', 'num_ctx', 'num_thread', with Ollama-specific keys suffixed "({$i18n.t('Ollama')})"), and the 'use_mmap' / 'use_mlock' toggles move behind an {#if admin} guard; slider, input, and toggle markup lost in extraction]

[diff header lost in extraction; the hunks below appear to belong to the user audio settings component]
+ [new STT row: a "{$i18n.t('Language')}" label with an input for the preferred transcription language]
{/if}
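With the language persisted in user settings, the transcription wrapper from `src/lib/apis/audio/index.ts` can pass it straight through; when the setting is unset the field is simply omitted from the request. A sketch:

```ts
import { transcribeAudio } from '$lib/apis/audio';

// Sketch: run STT with the user's configured language, if any. The
// optional `language` argument was added to transcribeAudio above;
// leaving it undefined keeps the old behavior (server-side detection).
const transcribeWithUserLanguage = async (
	file: File,
	settings: { audio?: { stt?: { language?: string } } }
) => await transcribeAudio(localStorage.token, file, settings?.audio?.stt?.language);
```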
@@ -270,15 +293,15 @@
{$i18n.t('Speech Playback Speed')}
- [removed markup: the old preset speed <select> built from {#each speedOptions as option}]
+ [new markup: a numeric input for arbitrary playback speed (class " text-sm text-right bg-transparent dark:text-gray-300 outline-hidden"), followed by an "x" suffix]
@@ -293,7 +316,7 @@ [hunk body lost in extraction]
@@ -330,7 +353,7 @@ [hunk body lost in extraction]
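Swapping the preset select for a numeric input means any positive factor can be stored. Applying it is the standard `playbackRate` property; in this sketch the clamping bounds are our own safety assumption, not something the diff enforces:

```ts
// Sketch: apply a granular speed to a TTS audio element. The 0.25 to 4
// clamp is our own guard (browsers reject extreme playbackRate values)
// and is not part of the diff.
const applyPlaybackSpeed = (audio: HTMLAudioElement, speed: number) => {
	audio.playbackRate = Math.min(Math.max(speed, 0.25), 4);
};

const audio = new Audio('/audio/sample-output.mp3'); // hypothetical file
applyPlaybackSpeed(audio, 1.25); // values like 1.25 or 1.9 now work, not just presets
```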
diff --git a/src/lib/components/chat/Settings/Chats.svelte b/src/lib/components/chat/Settings/Chats.svelte index 8a0ada02d..7ef0d7a56 100644 --- a/src/lib/components/chat/Settings/Chats.svelte +++ b/src/lib/components/chat/Settings/Chats.svelte @@ -16,7 +16,7 @@ import { onMount, getContext } from 'svelte'; import { goto } from '$app/navigation'; import { toast } from 'svelte-sonner'; - import ArchivedChatsModal from '$lib/components/layout/Sidebar/ArchivedChatsModal.svelte'; + import ArchivedChatsModal from '$lib/components/layout/ArchivedChatsModal.svelte'; const i18n = getContext('i18n'); @@ -105,7 +105,7 @@ }; - +
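Since `getArchivedChatList` is now paginated, callers that previously received everything in one response need a paging loop. Roughly like this, where the page size and filter shape are assumptions (the backend defines both):

```ts
import { getArchivedChatList } from '$lib/apis/chats';

// Sketch: page through archived chats until a short or empty page
// signals the end. ASSUMED_PAGE_SIZE is illustrative only.
const ASSUMED_PAGE_SIZE = 60;

const fetchAllArchivedChats = async (token: string, filter?: object) => {
	const chats: any[] = [];
	let page = 1;

	for (;;) {
		const batch = await getArchivedChatList(token, page, filter);
		if (!batch || batch.length === 0) break;

		chats.push(...batch);
		if (batch.length < ASSUMED_PAGE_SIZE) break;
		page += 1;
	}

	return chats;
};
```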
diff --git a/src/lib/components/chat/Settings/Personalization/AddMemoryModal.svelte b/src/lib/components/chat/Settings/Personalization/AddMemoryModal.svelte index 4854f48cd..a5bdb4bfc 100644 --- a/src/lib/components/chat/Settings/Personalization/AddMemoryModal.svelte +++ b/src/lib/components/chat/Settings/Personalization/AddMemoryModal.svelte @@ -70,8 +70,9 @@