diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.disabled similarity index 99% rename from .github/workflows/integration-test.yml rename to .github/workflows/integration-test.disabled index cb404f1fc..b248df4b5 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.disabled @@ -52,6 +52,8 @@ jobs: - name: Cypress run uses: cypress-io/github-action@v6 + env: + LIBGL_ALWAYS_SOFTWARE: 1 with: browser: chrome wait-on: 'http://localhost:3000' diff --git a/CHANGELOG.md b/CHANGELOG.md index 80f5f481e..666fdb53d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,58 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.5.4] - 2024-01-05 +## [0.5.7] - 2025-01-23 + +### Added + +- **🌍 Enhanced Internationalization (i18n)**: Refined and expanded translations for greater global accessibility and a smoother experience for international users. + +### Fixed + +- **🔗 Connection Model ID Resolution**: Resolved an issue preventing model IDs from registering in connections. +- **💡 Prefix ID for Ollama Connections**: Fixed a bug where prefix IDs in Ollama connections were non-functional. +- **🔧 Ollama Model Enable/Disable Functionality**: Addressed the issue of enable/disable toggles not working for Ollama base models. +- **🔒 RBAC Permissions for Tools and Models**: Corrected incorrect Role-Based Access Control (RBAC) permissions for tools and models, ensuring that users now only access features according to their assigned privileges, enhancing security and role clarity. + +## [0.5.6] - 2025-01-22 + +### Added + +- **🧠 Effortful Reasoning Control for OpenAI Models**: Introduced the reasoning_effort parameter in chat controls for supported OpenAI models, enabling users to fine-tune how much cognitive effort a model dedicates to its responses, offering greater customization for complex queries and reasoning tasks. + +### Fixed + +- **🔄 Chat Controls Loading UI Bug**: Resolved an issue where collapsible chat controls appeared as "loading," ensuring a smoother and more intuitive user experience for managing chat settings. + +### Changed + +- **🔧 Updated Ollama Model Creation**: Revamped the Ollama model creation method to align with their new JSON payload format, ensuring seamless compatibility and more efficient model setup workflows. + +## [0.5.5] - 2025-01-22 + +### Added + +- **🤔 Native 'Think' Tag Support**: Introduced the new 'think' tag support that visually displays how long the model is thinking, omitting the reasoning content itself until the next turn. Ideal for creating a more streamlined and focused interaction experience. +- **🖼️ Toggle Image Generation On/Off**: In the chat input menu, you can now easily toggle image generation before initiating chats, providing greater control and flexibility to suit your needs. +- **🔒 Chat Controls Permissions**: Admins can now disable chat controls access for users, offering tighter management and customization over user interactions. +- **🔍 Web Search & Image Generation Permissions**: Easily disable web search and image generation for specific users, improving workflow governance and security for certain environments. +- **🗂️ S3 and GCS Storage Provider Support**: Scaled deployments now benefit from expanded storage options with Amazon S3 and Google Cloud Storage seamlessly integrated as providers. +- **🎨 Enhanced Model Management**: Reintroduced the ability to download and delete models directly in the admin models settings page to minimize user confusion and aid efficient model management. +- **🔗 Improved Connection Handling**: Enhanced backend to smoothly handle multiple identical base URLs, allowing more flexible multi-instance configurations with fewer hiccups. +- **✨ General UI/UX Refinements**: Numerous tweaks across the WebUI make navigation and usability even more user-friendly and intuitive. +- **🌍 Translation Enhancements**: Various translation updates ensure smoother and more polished interactions for international users. + +### Fixed + +- **⚡ MPS Functionality for Mac Users**: Fixed MPS support, ensuring smooth performance and compatibility for Mac users leveraging MPS. +- **📡 Ollama Connection Management**: Resolved the issue where deleting all Ollama connections prevented adding new ones. + +### Changed + +- **⚙️ General Stability Refac**: Backend refactoring delivers a more stable, robust platform. +- **🖥️ Desktop App Preparations**: Ongoing work to support the upcoming Open WebUI desktop app. Follow our progress and updates here: https://github.com/open-webui/desktop + +## [0.5.4] - 2025-01-05 ### Added diff --git a/backend/open_webui/__init__.py b/backend/open_webui/__init__.py index de34a8bc7..d85be48da 100644 --- a/backend/open_webui/__init__.py +++ b/backend/open_webui/__init__.py @@ -5,12 +5,31 @@ from pathlib import Path import typer import uvicorn +from typing import Optional +from typing_extensions import Annotated app = typer.Typer() KEY_FILE = Path.cwd() / ".webui_secret_key" +def version_callback(value: bool): + if value: + from open_webui.env import VERSION + + typer.echo(f"Open WebUI version: {VERSION}") + raise typer.Exit() + + +@app.command() +def main( + version: Annotated[ + Optional[bool], typer.Option("--version", callback=version_callback) + ] = None, +): + pass + + @app.command() def serve( host: str = "0.0.0.0", diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index f1b1c14a5..d226b9b47 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -9,22 +9,22 @@ from urllib.parse import urlparse import chromadb import requests -import yaml -from open_webui.internal.db import Base, get_db +from pydantic import BaseModel +from sqlalchemy import JSON, Column, DateTime, Integer, func + from open_webui.env import ( - OPEN_WEBUI_DIR, DATA_DIR, + DATABASE_URL, ENV, FRONTEND_BUILD_DIR, + OFFLINE_MODE, + OPEN_WEBUI_DIR, WEBUI_AUTH, WEBUI_FAVICON_URL, WEBUI_NAME, log, - DATABASE_URL, - OFFLINE_MODE, ) -from pydantic import BaseModel -from sqlalchemy import JSON, Column, DateTime, Integer, func +from open_webui.internal.db import Base, get_db class EndpointFilter(logging.Filter): @@ -362,6 +362,30 @@ MICROSOFT_REDIRECT_URI = PersistentConfig( os.environ.get("MICROSOFT_REDIRECT_URI", ""), ) +GITHUB_CLIENT_ID = PersistentConfig( + "GITHUB_CLIENT_ID", + "oauth.github.client_id", + os.environ.get("GITHUB_CLIENT_ID", ""), +) + +GITHUB_CLIENT_SECRET = PersistentConfig( + "GITHUB_CLIENT_SECRET", + "oauth.github.client_secret", + os.environ.get("GITHUB_CLIENT_SECRET", ""), +) + +GITHUB_CLIENT_SCOPE = PersistentConfig( + "GITHUB_CLIENT_SCOPE", + "oauth.github.scope", + os.environ.get("GITHUB_CLIENT_SCOPE", "user:email"), +) + +GITHUB_CLIENT_REDIRECT_URI = PersistentConfig( + "GITHUB_CLIENT_REDIRECT_URI", + "oauth.github.redirect_uri", + os.environ.get("GITHUB_CLIENT_REDIRECT_URI", ""), +) + OAUTH_CLIENT_ID = PersistentConfig( "OAUTH_CLIENT_ID", "oauth.oidc.client_id", @@ -468,12 +492,20 @@ OAUTH_ALLOWED_DOMAINS = PersistentConfig( def load_oauth_providers(): OAUTH_PROVIDERS.clear() if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value: + + def google_oauth_register(client): + client.register( + name="google", + client_id=GOOGLE_CLIENT_ID.value, + client_secret=GOOGLE_CLIENT_SECRET.value, + server_metadata_url="https://accounts.google.com/.well-known/openid-configuration", + client_kwargs={"scope": GOOGLE_OAUTH_SCOPE.value}, + redirect_uri=GOOGLE_REDIRECT_URI.value, + ) + OAUTH_PROVIDERS["google"] = { - "client_id": GOOGLE_CLIENT_ID.value, - "client_secret": GOOGLE_CLIENT_SECRET.value, - "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration", - "scope": GOOGLE_OAUTH_SCOPE.value, "redirect_uri": GOOGLE_REDIRECT_URI.value, + "register": google_oauth_register, } if ( @@ -481,12 +513,44 @@ def load_oauth_providers(): and MICROSOFT_CLIENT_SECRET.value and MICROSOFT_CLIENT_TENANT_ID.value ): + + def microsoft_oauth_register(client): + client.register( + name="microsoft", + client_id=MICROSOFT_CLIENT_ID.value, + client_secret=MICROSOFT_CLIENT_SECRET.value, + server_metadata_url=f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration", + client_kwargs={ + "scope": MICROSOFT_OAUTH_SCOPE.value, + }, + redirect_uri=MICROSOFT_REDIRECT_URI.value, + ) + OAUTH_PROVIDERS["microsoft"] = { - "client_id": MICROSOFT_CLIENT_ID.value, - "client_secret": MICROSOFT_CLIENT_SECRET.value, - "server_metadata_url": f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration", - "scope": MICROSOFT_OAUTH_SCOPE.value, "redirect_uri": MICROSOFT_REDIRECT_URI.value, + "picture_url": "https://graph.microsoft.com/v1.0/me/photo/$value", + "register": microsoft_oauth_register, + } + + if GITHUB_CLIENT_ID.value and GITHUB_CLIENT_SECRET.value: + + def github_oauth_register(client): + client.register( + name="github", + client_id=GITHUB_CLIENT_ID.value, + client_secret=GITHUB_CLIENT_SECRET.value, + access_token_url="https://github.com/login/oauth/access_token", + authorize_url="https://github.com/login/oauth/authorize", + api_base_url="https://api.github.com", + userinfo_endpoint="https://api.github.com/user", + client_kwargs={"scope": GITHUB_CLIENT_SCOPE.value}, + redirect_uri=GITHUB_CLIENT_REDIRECT_URI.value, + ) + + OAUTH_PROVIDERS["github"] = { + "redirect_uri": GITHUB_CLIENT_REDIRECT_URI.value, + "register": github_oauth_register, + "sub_claim": "id", } if ( @@ -494,13 +558,23 @@ def load_oauth_providers(): and OAUTH_CLIENT_SECRET.value and OPENID_PROVIDER_URL.value ): + + def oidc_oauth_register(client): + client.register( + name="oidc", + client_id=OAUTH_CLIENT_ID.value, + client_secret=OAUTH_CLIENT_SECRET.value, + server_metadata_url=OPENID_PROVIDER_URL.value, + client_kwargs={ + "scope": OAUTH_SCOPES.value, + }, + redirect_uri=OPENID_REDIRECT_URI.value, + ) + OAUTH_PROVIDERS["oidc"] = { - "client_id": OAUTH_CLIENT_ID.value, - "client_secret": OAUTH_CLIENT_SECRET.value, - "server_metadata_url": OPENID_PROVIDER_URL.value, - "scope": OAUTH_SCOPES.value, "name": OAUTH_PROVIDER_NAME.value, "redirect_uri": OPENID_REDIRECT_URI.value, + "register": oidc_oauth_register, } @@ -580,7 +654,7 @@ if CUSTOM_NAME: # STORAGE PROVIDER #################################### -STORAGE_PROVIDER = os.environ.get("STORAGE_PROVIDER", "") # defaults to local, s3 +STORAGE_PROVIDER = os.environ.get("STORAGE_PROVIDER", "local") # defaults to local, s3 S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None) S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None) @@ -588,6 +662,11 @@ S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None) S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None) S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None) +GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None) +GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get( + "GOOGLE_APPLICATION_CREDENTIALS_JSON", None +) + #################################### # File Upload DIR #################################### @@ -819,6 +898,10 @@ USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = ( os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true" ) +USER_PERMISSIONS_CHAT_CONTROLS = ( + os.environ.get("USER_PERMISSIONS_CHAT_CONTROLS", "True").lower() == "true" +) + USER_PERMISSIONS_CHAT_FILE_UPLOAD = ( os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true" ) @@ -835,23 +918,39 @@ USER_PERMISSIONS_CHAT_TEMPORARY = ( os.environ.get("USER_PERMISSIONS_CHAT_TEMPORARY", "True").lower() == "true" ) +USER_PERMISSIONS_FEATURES_WEB_SEARCH = ( + os.environ.get("USER_PERMISSIONS_FEATURES_WEB_SEARCH", "True").lower() == "true" +) + +USER_PERMISSIONS_FEATURES_IMAGE_GENERATION = ( + os.environ.get("USER_PERMISSIONS_FEATURES_IMAGE_GENERATION", "True").lower() + == "true" +) + +DEFAULT_USER_PERMISSIONS = { + "workspace": { + "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS, + "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS, + "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS, + "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS, + }, + "chat": { + "controls": USER_PERMISSIONS_CHAT_CONTROLS, + "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD, + "delete": USER_PERMISSIONS_CHAT_DELETE, + "edit": USER_PERMISSIONS_CHAT_EDIT, + "temporary": USER_PERMISSIONS_CHAT_TEMPORARY, + }, + "features": { + "web_search": USER_PERMISSIONS_FEATURES_WEB_SEARCH, + "image_generation": USER_PERMISSIONS_FEATURES_IMAGE_GENERATION, + }, +} + USER_PERMISSIONS = PersistentConfig( "USER_PERMISSIONS", "user.permissions", - { - "workspace": { - "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS, - "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS, - "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS, - "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS, - }, - "chat": { - "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD, - "delete": USER_PERMISSIONS_CHAT_DELETE, - "edit": USER_PERMISSIONS_CHAT_EDIT, - "temporary": USER_PERMISSIONS_CHAT_TEMPORARY, - }, - }, + DEFAULT_USER_PERMISSIONS, ) ENABLE_CHANNELS = PersistentConfig( @@ -1034,6 +1133,32 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } {{MESSAGES:END:6}} """ +IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE", + "task.image.prompt_template", + os.environ.get("IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE", ""), +) + +DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = """### Task: +Generate a detailed prompt for am image generation task based on the given language and context. Describe the image as if you were explaining it to someone who cannot see it. Include relevant details, colors, shapes, and any other important elements. + +### Guidelines: +- Be descriptive and detailed, focusing on the most important aspects of the image. +- Avoid making assumptions or adding information not present in the image. +- Use the chat's primary language; default to English if multilingual. +- If the image is too complex, focus on the most prominent elements. + +### Output: +Strictly return in JSON format: +{ + "prompt": "Your detailed description here." +} + +### Chat History: + +{{MESSAGES:END:6}} +""" + ENABLE_TAGS_GENERATION = PersistentConfig( "ENABLE_TAGS_GENERATION", "task.tags.enable", @@ -1193,6 +1318,7 @@ CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" # Milvus MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db") +MILVUS_DB = os.environ.get("MILVUS_DB", "default") # Qdrant QDRANT_URI = os.environ.get("QDRANT_URI", None) @@ -1614,6 +1740,13 @@ ENABLE_IMAGE_GENERATION = PersistentConfig( "image_generation.enable", os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true", ) + +ENABLE_IMAGE_PROMPT_GENERATION = PersistentConfig( + "ENABLE_IMAGE_PROMPT_GENERATION", + "image_generation.prompt.enable", + os.environ.get("ENABLE_IMAGE_PROMPT_GENERATION", "true").lower() == "true", +) + AUTOMATIC1111_BASE_URL = PersistentConfig( "AUTOMATIC1111_BASE_URL", "image_generation.automatic1111.base_url", @@ -1943,6 +2076,12 @@ LDAP_SERVER_PORT = PersistentConfig( int(os.environ.get("LDAP_SERVER_PORT", "389")), ) +LDAP_ATTRIBUTE_FOR_MAIL = PersistentConfig( + "LDAP_ATTRIBUTE_FOR_MAIL", + "ldap.server.attribute_for_mail", + os.environ.get("LDAP_ATTRIBUTE_FOR_MAIL", "mail"), +) + LDAP_ATTRIBUTE_FOR_USERNAME = PersistentConfig( "LDAP_ATTRIBUTE_FOR_USERNAME", "ldap.server.attribute_for_username", diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index c5fdfabfb..cb65e0d77 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -113,6 +113,7 @@ class TASKS(str, Enum): TAGS_GENERATION = "tags_generation" EMOJI_GENERATION = "emoji_generation" QUERY_GENERATION = "query_generation" + IMAGE_PROMPT_GENERATION = "image_prompt_generation" AUTOCOMPLETE_GENERATION = "autocomplete_generation" FUNCTION_CALLING = "function_calling" MOA_RESPONSE_GENERATION = "moa_response_generation" diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index f16f2ea6e..77e632ccc 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -54,6 +54,8 @@ else: DEVICE_TYPE = "cpu" try: + import torch + if torch.backends.mps.is_available() and torch.backends.mps.is_built(): DEVICE_TYPE = "mps" except Exception: @@ -272,6 +274,8 @@ DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") if "postgres://" in DATABASE_URL: DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://") +DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None) + DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0) if DATABASE_POOL_SIZE == "": diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py index 5f19e695d..840f571cc 100644 --- a/backend/open_webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -7,6 +7,7 @@ from open_webui.internal.wrappers import register_connection from open_webui.env import ( OPEN_WEBUI_DIR, DATABASE_URL, + DATABASE_SCHEMA, SRC_LOG_LEVELS, DATABASE_POOL_MAX_OVERFLOW, DATABASE_POOL_RECYCLE, @@ -14,7 +15,7 @@ from open_webui.env import ( DATABASE_POOL_TIMEOUT, ) from peewee_migrate import Router -from sqlalchemy import Dialect, create_engine, types +from sqlalchemy import Dialect, create_engine, MetaData, types from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.pool import QueuePool, NullPool @@ -99,7 +100,8 @@ else: SessionLocal = sessionmaker( autocommit=False, autoflush=False, bind=engine, expire_on_commit=False ) -Base = declarative_base() +metadata_obj = MetaData(schema=DATABASE_SCHEMA) +Base = declarative_base(metadata=metadata_obj) Session = scoped_session(SessionLocal) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 52337a640..00270aabc 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -108,6 +108,7 @@ from open_webui.config import ( COMFYUI_WORKFLOW, COMFYUI_WORKFLOW_NODES, ENABLE_IMAGE_GENERATION, + ENABLE_IMAGE_PROMPT_GENERATION, IMAGE_GENERATION_ENGINE, IMAGE_GENERATION_MODEL, IMAGE_SIZE, @@ -225,6 +226,7 @@ from open_webui.config import ( LDAP_SERVER_LABEL, LDAP_SERVER_HOST, LDAP_SERVER_PORT, + LDAP_ATTRIBUTE_FOR_MAIL, LDAP_ATTRIBUTE_FOR_USERNAME, LDAP_SEARCH_FILTERS, LDAP_SEARCH_BASE, @@ -254,6 +256,7 @@ from open_webui.config import ( ENABLE_AUTOCOMPLETE_GENERATION, TITLE_GENERATION_PROMPT_TEMPLATE, TAGS_GENERATION_PROMPT_TEMPLATE, + IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, QUERY_GENERATION_PROMPT_TEMPLATE, AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, @@ -437,6 +440,7 @@ app.state.config.ENABLE_LDAP = ENABLE_LDAP app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT +app.state.config.LDAP_ATTRIBUTE_FOR_MAIL = LDAP_ATTRIBUTE_FOR_MAIL app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME app.state.config.LDAP_APP_DN = LDAP_APP_DN app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD @@ -572,6 +576,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.config.IMAGE_GENERATION_ENGINE = IMAGE_GENERATION_ENGINE app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION +app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY @@ -642,6 +647,10 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE +app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = ( + IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE +) + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE ) diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 8e721da78..73ff6c102 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -393,7 +393,7 @@ class ChatTable: limit: int = 50, ) -> list[ChatModel]: with get_db() as db: - query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None) + query = db.query(Chat).filter_by(user_id=user_id) if not include_archived: query = query.filter_by(archived=False) diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index 94f4cfae8..763340fbc 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -80,12 +80,11 @@ class GroupResponse(BaseModel): class GroupForm(BaseModel): name: str description: str + permissions: Optional[dict] = None class GroupUpdateForm(GroupForm): - permissions: Optional[dict] = None user_ids: Optional[list[str]] = None - admin_ids: Optional[list[str]] = None class GroupTable: @@ -95,7 +94,7 @@ class GroupTable: with get_db() as db: group = GroupModel( **{ - **form_data.model_dump(), + **form_data.model_dump(exclude_none=True), "id": str(uuid.uuid4()), "user_id": user_id, "created_at": int(time.time()), @@ -189,5 +188,24 @@ class GroupTable: except Exception: return False + def remove_user_from_all_groups(self, user_id: str) -> bool: + with get_db() as db: + try: + groups = self.get_groups_by_member_id(user_id) + + for group in groups: + group.user_ids.remove(user_id) + db.query(Group).filter_by(id=group.id).update( + { + "user_ids": group.user_ids, + "updated_at": int(time.time()), + } + ) + db.commit() + + return True + except Exception: + return False + Groups = GroupTable() diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 9ba127605..5c196281f 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -2,7 +2,12 @@ import time from typing import Optional from open_webui.internal.db import Base, JSONField, get_db + + from open_webui.models.chats import Chats +from open_webui.models.groups import Groups + + from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text @@ -268,9 +273,11 @@ class UsersTable: def delete_user_by_id(self, id: str) -> bool: try: + # Remove User from Groups + Groups.remove_user_from_all_groups(id) + # Delete User Chats result = Chats.delete_chats_by_user_id(id) - if result: with get_db() as db: # Delete User @@ -300,5 +307,10 @@ class UsersTable: except Exception: return None + def get_valid_user_ids(self, user_ids: list[str]) -> list[str]: + with get_db() as db: + users = db.query(User).filter(User.id.in_(user_ids)).all() + return [user.id for user in users] + Users = UsersTable() diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 544a65a89..08ab75786 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -11,6 +11,8 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document + +from open_webui.config import VECTOR_DB from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message @@ -203,7 +205,12 @@ def query_collection( else: pass - return merge_and_sort_query_results(results, k=k) + if VECTOR_DB == "chroma": + # Chroma uses unconventional cosine similarity, so we don't need to reverse the results + # https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections + return merge_and_sort_query_results(results, k=k, reverse=False) + else: + return merge_and_sort_query_results(results, k=k, reverse=True) def query_collection_with_hybrid_search( @@ -239,7 +246,12 @@ def query_collection_with_hybrid_search( "Hybrid search failed for all collections. Using Non hybrid search as fallback." ) - return merge_and_sort_query_results(results, k=k, reverse=True) + if VECTOR_DB == "chroma": + # Chroma uses unconventional cosine similarity, so we don't need to reverse the results + # https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections + return merge_and_sort_query_results(results, k=k, reverse=False) + else: + return merge_and_sort_query_results(results, k=k, reverse=True) def get_embedding_function( diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py index 00d73a889..c40618fcc 100644 --- a/backend/open_webui/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -51,8 +51,8 @@ class ChromaClient: def has_collection(self, collection_name: str) -> bool: # Check if the collection exists based on the collection name. - collections = self.client.list_collections() - return collection_name in [collection.name for collection in collections] + collection_names = self.client.list_collections() + return collection_name in collection_names def delete_collection(self, collection_name: str): # Delete the collection based on the collection name. diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py index 31d890664..bdfa16eb6 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -7,13 +7,14 @@ from typing import Optional from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( MILVUS_URI, + MILVUS_DB, ) class MilvusClient: def __init__(self): self.collection_prefix = "open_webui" - self.client = Client(uri=MILVUS_URI) + self.client = Client(uri=MILVUS_URI, database=MILVUS_DB) def _result_to_get_result(self, result) -> GetResult: ids = [] diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index 64b6fd6c7..341b3056f 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -9,6 +9,7 @@ from sqlalchemy import ( select, text, Text, + Table, values, ) from sqlalchemy.sql import true @@ -18,6 +19,7 @@ from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker from sqlalchemy.dialects.postgresql import JSONB, array from pgvector.sqlalchemy import Vector from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.exc import NoSuchTableError from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH @@ -92,31 +94,34 @@ class PgvectorClient: Raises an exception if there is a mismatch. """ metadata = MetaData() - metadata.reflect(bind=self.session.bind, only=["document_chunk"]) + try: + # Attempt to reflect the 'document_chunk' table + document_chunk_table = Table( + "document_chunk", metadata, autoload_with=self.session.bind + ) + except NoSuchTableError: + # Table does not exist; no action needed + return - if "document_chunk" in metadata.tables: - document_chunk_table = metadata.tables["document_chunk"] - if "vector" in document_chunk_table.columns: - vector_column = document_chunk_table.columns["vector"] - vector_type = vector_column.type - if isinstance(vector_type, Vector): - db_vector_length = vector_type.dim - if db_vector_length != VECTOR_LENGTH: - raise Exception( - f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. " - "Cannot change vector size after initialization without migrating the data." - ) - else: + # Proceed to check the vector column + if "vector" in document_chunk_table.columns: + vector_column = document_chunk_table.columns["vector"] + vector_type = vector_column.type + if isinstance(vector_type, Vector): + db_vector_length = vector_type.dim + if db_vector_length != VECTOR_LENGTH: raise Exception( - "The 'vector' column exists but is not of type 'Vector'." + f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. " + "Cannot change vector size after initialization without migrating the data." ) else: raise Exception( - "The 'vector' column does not exist in the 'document_chunk' table." + "The 'vector' column exists but is not of type 'Vector'." ) else: - # Table does not exist yet; no action needed - pass + raise Exception( + "The 'vector' column does not exist in the 'document_chunk' table." + ) def adjust_vector_length(self, vector: List[float]) -> List[float]: # Adjust vector to have length VECTOR_LENGTH diff --git a/backend/open_webui/retrieval/web/bing.py b/backend/open_webui/retrieval/web/bing.py index 09beb3460..0a3ba4621 100644 --- a/backend/open_webui/retrieval/web/bing.py +++ b/backend/open_webui/retrieval/web/bing.py @@ -23,7 +23,7 @@ def search_bing( filter_list: Optional[list[str]] = None, ) -> list[SearchResult]: mkt = locale - params = {"q": query, "mkt": mkt, "answerCount": count} + params = {"q": query, "mkt": mkt, "count": count} headers = {"Ocp-Apim-Subscription-Key": subscription_key} try: diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 1a89058fa..47baeb0ac 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -51,7 +51,7 @@ from open_webui.utils.access_control import get_permissions from typing import Optional, List from ssl import CERT_REQUIRED, PROTOCOL_TLS -from ldap3 import Server, Connection, ALL, Tls +from ldap3 import Server, Connection, NONE, Tls from ldap3.utils.conv import escape_filter_chars router = APIRouter() @@ -170,6 +170,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL LDAP_SERVER_HOST = request.app.state.config.LDAP_SERVER_HOST LDAP_SERVER_PORT = request.app.state.config.LDAP_SERVER_PORT + LDAP_ATTRIBUTE_FOR_MAIL = request.app.state.config.LDAP_ATTRIBUTE_FOR_MAIL LDAP_ATTRIBUTE_FOR_USERNAME = request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME LDAP_SEARCH_BASE = request.app.state.config.LDAP_SEARCH_BASE LDAP_SEARCH_FILTERS = request.app.state.config.LDAP_SEARCH_FILTERS @@ -201,7 +202,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): server = Server( host=LDAP_SERVER_HOST, port=LDAP_SERVER_PORT, - get_info=ALL, + get_info=NONE, use_ssl=LDAP_USE_TLS, tls=tls, ) @@ -218,7 +219,11 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): search_success = connection_app.search( search_base=LDAP_SEARCH_BASE, search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})", - attributes=[f"{LDAP_ATTRIBUTE_FOR_USERNAME}", "mail", "cn"], + attributes=[ + f"{LDAP_ATTRIBUTE_FOR_USERNAME}", + f"{LDAP_ATTRIBUTE_FOR_MAIL}", + "cn", + ], ) if not search_success: @@ -226,7 +231,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): entry = connection_app.entries[0] username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower() - mail = str(entry["mail"]) + mail = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"]) + if not mail or mail == "" or mail == "[]": + raise HTTPException(400, f"User {form_data.user} does not have mail.") cn = str(entry["cn"]) user_dn = entry.entry_dn @@ -691,6 +698,7 @@ class LdapServerConfig(BaseModel): label: str host: str port: Optional[int] = None + attribute_for_mail: str = "mail" attribute_for_username: str = "uid" app_dn: str app_dn_password: str @@ -707,6 +715,7 @@ async def get_ldap_server(request: Request, user=Depends(get_admin_user)): "label": request.app.state.config.LDAP_SERVER_LABEL, "host": request.app.state.config.LDAP_SERVER_HOST, "port": request.app.state.config.LDAP_SERVER_PORT, + "attribute_for_mail": request.app.state.config.LDAP_ATTRIBUTE_FOR_MAIL, "attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME, "app_dn": request.app.state.config.LDAP_APP_DN, "app_dn_password": request.app.state.config.LDAP_APP_PASSWORD, @@ -725,6 +734,7 @@ async def update_ldap_server( required_fields = [ "label", "host", + "attribute_for_mail", "attribute_for_username", "app_dn", "app_dn_password", @@ -743,6 +753,7 @@ async def update_ldap_server( request.app.state.config.LDAP_SERVER_LABEL = form_data.label request.app.state.config.LDAP_SERVER_HOST = form_data.host request.app.state.config.LDAP_SERVER_PORT = form_data.port + request.app.state.config.LDAP_ATTRIBUTE_FOR_MAIL = form_data.attribute_for_mail request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = ( form_data.attribute_for_username ) @@ -758,6 +769,7 @@ async def update_ldap_server( "label": request.app.state.config.LDAP_SERVER_LABEL, "host": request.app.state.config.LDAP_SERVER_HOST, "port": request.app.state.config.LDAP_SERVER_PORT, + "attribute_for_mail": request.app.state.config.LDAP_ATTRIBUTE_FOR_MAIL, "attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME, "app_dn": request.app.state.config.LDAP_APP_DN, "app_dn_password": request.app.state.config.LDAP_APP_PASSWORD, diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 9e36d98b7..b648fccc2 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -345,6 +345,8 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): async def delete_file_by_id(id: str, user=Depends(get_verified_user)): file = Files.get_file_by_id(id) if file and (file.user_id == user.id or user.role == "admin"): + # We should add Chroma cleanup here + result = Files.delete_file_by_id(id) if result: try: diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index e8f8994a4..5b5130f71 100644 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -2,6 +2,8 @@ import os from pathlib import Path from typing import Optional + +from open_webui.models.users import Users from open_webui.models.groups import ( Groups, GroupForm, @@ -80,6 +82,9 @@ async def update_group_by_id( id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user) ): try: + if form_data.user_ids: + form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids) + group = Groups.update_group_by_id(id, form_data) if group: return group diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index c40d12522..7afd9d106 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -43,6 +43,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)): return { "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, "engine": request.app.state.config.IMAGE_GENERATION_ENGINE, + "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION, "openai": { "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, @@ -86,6 +87,7 @@ class ComfyUIConfigForm(BaseModel): class ConfigForm(BaseModel): enabled: bool engine: str + prompt_generation: bool openai: OpenAIConfigForm automatic1111: Automatic1111ConfigForm comfyui: ComfyUIConfigForm @@ -98,6 +100,10 @@ async def update_config( request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled + request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ( + form_data.prompt_generation + ) + request.app.state.config.IMAGES_OPENAI_API_BASE_URL = ( form_data.openai.OPENAI_API_BASE_URL ) @@ -137,6 +143,7 @@ async def update_config( return { "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, "engine": request.app.state.config.IMAGE_GENERATION_ENGINE, + "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION, "openai": { "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, @@ -408,10 +415,14 @@ def save_b64_image(b64_str): return None -def save_url_image(url): +def save_url_image(url, headers=None): image_id = str(uuid.uuid4()) try: - r = requests.get(url) + if headers: + r = requests.get(url, headers=headers) + else: + r = requests.get(url) + r.raise_for_status() if r.headers["content-type"].split("/")[0] == "image": mime_type = r.headers["content-type"] @@ -535,7 +546,13 @@ async def image_generations( images = [] for image in res["data"]: - image_filename = save_url_image(image["url"]) + headers = None + if request.app.state.config.COMFYUI_API_KEY: + headers = { + "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}" + } + + image_filename = save_url_image(image["url"], headers) images.append({"url": f"/cache/image/generations/{image_filename}"}) file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index ad67cc31f..cce3d6311 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -17,7 +17,7 @@ from open_webui.routers.retrieval import ( process_files_batch, BatchProcessFilesForm, ) - +from open_webui.storage.provider import Storage from open_webui.constants import ERROR_MESSAGES from open_webui.utils.auth import get_verified_user @@ -25,6 +25,7 @@ from open_webui.utils.access_control import has_access, has_permission from open_webui.env import SRC_LOG_LEVELS +from open_webui.models.models import Models, ModelForm log = logging.getLogger(__name__) @@ -212,8 +213,12 @@ async def update_knowledge_by_id( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.NOT_FOUND, ) - - if knowledge.user_id != user.id and user.role != "admin": + # Is the user the original creator, in a group with write access, or an admin + if ( + knowledge.user_id != user.id + and not has_access(user.id, "write", knowledge.access_control) + and user.role != "admin" + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -419,6 +424,18 @@ def remove_file_from_knowledge_by_id( collection_name=knowledge.id, filter={"file_id": form_data.file_id} ) + # Remove the file's collection from vector database + file_collection = f"file-{form_data.file_id}" + if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection): + VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection) + + # Delete physical file + if file.path: + Storage.delete_file(file.path) + + # Delete file from database + Files.delete_file_by_id(form_data.file_id) + if knowledge: data = knowledge.data or {} file_ids = data.get("file_ids", []) @@ -473,6 +490,36 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) + log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})") + + # Get all models + models = Models.get_all_models() + log.info(f"Found {len(models)} models to check for knowledge base {id}") + + # Update models that reference this knowledge base + for model in models: + if model.meta and hasattr(model.meta, "knowledge"): + knowledge_list = model.meta.knowledge or [] + # Filter out the deleted knowledge base + updated_knowledge = [k for k in knowledge_list if k.get("id") != id] + + # If the knowledge list changed, update the model + if len(updated_knowledge) != len(knowledge_list): + log.info(f"Updating model {model.id} to remove knowledge base {id}") + model.meta.knowledge = updated_knowledge + # Create a ModelForm for the update + model_form = ModelForm( + id=model.id, + name=model.name, + base_model_id=model.base_model_id, + meta=model.meta, + params=model.params, + access_control=model.access_control, + is_active=model.is_active, + ) + Models.update_model_by_id(model.id, model_form) + + # Clean up vector DB try: VECTOR_DB_CLIENT.delete_collection(collection_name=id) except Exception as e: diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index db981a913..6c8519b2c 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -155,6 +155,16 @@ async def update_model_by_id( detail=ERROR_MESSAGES.NOT_FOUND, ) + if ( + model.user_id != user.id + and not has_access(user.id, "write", model.access_control) + and user.role != "admin" + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + model = Models.update_model_by_id(id, form_data) return model diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 275146c72..261cd5ba3 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -152,10 +152,12 @@ async def send_post_request( ) -def get_api_key(url, configs): +def get_api_key(idx, url, configs): parsed_url = urlparse(url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - return configs.get(base_url, {}).get("key", None) + return configs.get(str(idx), configs.get(base_url, {})).get( + "key", None + ) # Legacy support ########################################## @@ -238,11 +240,13 @@ async def update_config( request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS - # Remove any extra configs - config_urls = request.app.state.config.OLLAMA_API_CONFIGS.keys() - for url in list(request.app.state.config.OLLAMA_BASE_URLS): - if url not in config_urls: - request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None) + # Remove the API configs that are not in the API URLS + keys = list(map(str, range(len(request.app.state.config.OLLAMA_BASE_URLS)))) + request.app.state.config.OLLAMA_API_CONFIGS = { + key: value + for key, value in request.app.state.config.OLLAMA_API_CONFIGS.items() + if key in keys + } return { "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API, @@ -256,12 +260,19 @@ async def get_all_models(request: Request): log.info("get_all_models()") if request.app.state.config.ENABLE_OLLAMA_API: request_tasks = [] - for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS): - if url not in request.app.state.config.OLLAMA_API_CONFIGS: + if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and ( + url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support + ): request_tasks.append(send_get_request(f"{url}/api/tags")) else: - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + str(idx), + request.app.state.config.OLLAMA_API_CONFIGS.get( + url, {} + ), # Legacy support + ) + enable = api_config.get("enable", True) key = api_config.get("key", None) @@ -275,7 +286,12 @@ async def get_all_models(request: Request): for idx, response in enumerate(responses): if response: url = request.app.state.config.OLLAMA_BASE_URLS[idx] - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + str(idx), + request.app.state.config.OLLAMA_API_CONFIGS.get( + url, {} + ), # Legacy support + ) prefix_id = api_config.get("prefix_id", None) model_ids = api_config.get("model_ids", []) @@ -349,7 +365,7 @@ async def get_ollama_tags( models = await get_all_models(request) else: url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) r = None try: @@ -393,11 +409,14 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): request_tasks = [ send_get_request( f"{url}/api/version", - request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( - "key", None - ), + request.app.state.config.OLLAMA_API_CONFIGS.get( + str(idx), + request.app.state.config.OLLAMA_API_CONFIGS.get( + url, {} + ), # Legacy support + ).get("key", None), ) - for url in request.app.state.config.OLLAMA_BASE_URLS + for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS) ] responses = await asyncio.gather(*request_tasks) responses = list(filter(lambda x: x is not None, responses)) @@ -454,11 +473,14 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u request_tasks = [ send_get_request( f"{url}/api/ps", - request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( - "key", None - ), + request.app.state.config.OLLAMA_API_CONFIGS.get( + str(idx), + request.app.state.config.OLLAMA_API_CONFIGS.get( + url, {} + ), # Legacy support + ).get("key", None), ) - for url in request.app.state.config.OLLAMA_BASE_URLS + for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS) ] responses = await asyncio.gather(*request_tasks) @@ -488,7 +510,7 @@ async def pull_model( return await send_post_request( url=f"{url}/api/pull", payload=json.dumps(payload), - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -524,16 +546,17 @@ async def push_model( return await send_post_request( url=f"{url}/api/push", payload=form_data.model_dump_json(exclude_none=True).encode(), - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), ) class CreateModelForm(BaseModel): - name: str - modelfile: Optional[str] = None + model: Optional[str] = None stream: Optional[bool] = None path: Optional[str] = None + model_config = ConfigDict(extra="allow") + @router.post("/api/create") @router.post("/api/create/{url_idx}") @@ -549,7 +572,7 @@ async def create_model( return await send_post_request( url=f"{url}/api/create", payload=form_data.model_dump_json(exclude_none=True).encode(), - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -579,7 +602,7 @@ async def copy_model( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) try: r = requests.request( @@ -634,7 +657,7 @@ async def delete_model( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) try: r = requests.request( @@ -684,7 +707,7 @@ async def show_model_info( url_idx = random.choice(models[form_data.name]["urls"]) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) try: r = requests.request( @@ -753,7 +776,7 @@ async def embed( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) try: r = requests.request( @@ -822,7 +845,7 @@ async def embeddings( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) try: r = requests.request( @@ -897,7 +920,10 @@ async def generate_completion( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + str(url_idx), + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support + ) prefix_id = api_config.get("prefix_id", None) if prefix_id: @@ -906,7 +932,7 @@ async def generate_completion( return await send_post_request( url=f"{url}/api/generate", payload=form_data.model_dump_json(exclude_none=True).encode(), - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -936,7 +962,7 @@ async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = ) url_idx = random.choice(models[model].get("urls", [])) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - return url + return url, url_idx @router.post("/api/chat") @@ -1004,8 +1030,11 @@ async def generate_chat_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = await get_ollama_url(request, payload["model"], url_idx) - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + url, url_idx = await get_ollama_url(request, payload["model"], url_idx) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + str(url_idx), + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support + ) prefix_id = api_config.get("prefix_id", None) if prefix_id: @@ -1015,7 +1044,7 @@ async def generate_chat_completion( url=f"{url}/api/chat", payload=json.dumps(payload), stream=form_data.stream, - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), content_type="application/x-ndjson", ) @@ -1103,8 +1132,11 @@ async def generate_openai_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = await get_ollama_url(request, payload["model"], url_idx) - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + url, url_idx = await get_ollama_url(request, payload["model"], url_idx) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + str(url_idx), + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support + ) prefix_id = api_config.get("prefix_id", None) @@ -1115,7 +1147,7 @@ async def generate_openai_completion( url=f"{url}/v1/completions", payload=json.dumps(payload), stream=payload.get("stream", False), - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -1177,8 +1209,11 @@ async def generate_openai_chat_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = await get_ollama_url(request, payload["model"], url_idx) - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + url, url_idx = await get_ollama_url(request, payload["model"], url_idx) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + str(url_idx), + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support + ) prefix_id = api_config.get("prefix_id", None) if prefix_id: @@ -1188,7 +1223,7 @@ async def generate_openai_chat_completion( url=f"{url}/v1/chat/completions", payload=json.dumps(payload), stream=payload.get("stream", False), - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), ) diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 4ab381ea4..f7d7fd294 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -145,11 +145,13 @@ async def update_config( request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS - # Remove any extra configs - config_urls = request.app.state.config.OPENAI_API_CONFIGS.keys() - for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS): - if url not in config_urls: - request.app.state.config.OPENAI_API_CONFIGS.pop(url, None) + # Remove the API configs that are not in the API URLS + keys = list(map(str, range(len(request.app.state.config.OPENAI_API_BASE_URLS)))) + request.app.state.config.OPENAI_API_CONFIGS = { + key: value + for key, value in request.app.state.config.OPENAI_API_CONFIGS.items() + if key in keys + } return { "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API, @@ -264,14 +266,21 @@ async def get_all_models_responses(request: Request) -> list: request_tasks = [] for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS): - if url not in request.app.state.config.OPENAI_API_CONFIGS: + if (str(idx) not in request.app.state.config.OPENAI_API_CONFIGS) and ( + url not in request.app.state.config.OPENAI_API_CONFIGS # Legacy support + ): request_tasks.append( send_get_request( f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx] ) ) else: - api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + str(idx), + request.app.state.config.OPENAI_API_CONFIGS.get( + url, {} + ), # Legacy support + ) enable = api_config.get("enable", True) model_ids = api_config.get("model_ids", []) @@ -310,7 +319,12 @@ async def get_all_models_responses(request: Request) -> list: for idx, response in enumerate(responses): if response: url = request.app.state.config.OPENAI_API_BASE_URLS[idx] - api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + str(idx), + request.app.state.config.OPENAI_API_CONFIGS.get( + url, {} + ), # Legacy support + ) prefix_id = api_config.get("prefix_id", None) @@ -573,6 +587,7 @@ async def generate_chat_completion( detail="Model not found", ) + await get_all_models(request) model = request.app.state.OPENAI_MODELS.get(model_id) if model: idx = model["urlIdx"] @@ -584,7 +599,10 @@ async def generate_chat_completion( # Get the API config for the model api_config = request.app.state.config.OPENAI_API_CONFIGS.get( - request.app.state.config.OPENAI_API_BASE_URLS[idx], {} + str(idx), + request.app.state.config.OPENAI_API_CONFIGS.get( + request.app.state.config.OPENAI_API_BASE_URLS[idx], {} + ), # Legacy support ) prefix_id = api_config.get("prefix_id", None) diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index 4f1c48482..014e5652e 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -112,7 +112,12 @@ async def update_prompt_by_command( detail=ERROR_MESSAGES.NOT_FOUND, ) - if prompt.user_id != user.id and user.role != "admin": + # Is the user the original creator, in a group with write access, or an admin + if ( + prompt.user_id != user.id + and not has_access(user.id, "write", prompt.access_control) + and user.role != "admin" + ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 255cff112..166738876 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -385,7 +385,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "serply_api_key": request.app.state.config.SERPLY_API_KEY, "tavily_api_key": request.app.state.config.TAVILY_API_KEY, "searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY, - "seaarchapi_engine": request.app.state.config.SEARCHAPI_ENGINE, + "searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE, "jina_api_key": request.app.state.config.JINA_API_KEY, "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT, "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index 7d14a9d18..6d7343c8a 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -9,6 +9,7 @@ from open_webui.utils.chat import generate_chat_completion from open_webui.utils.task import ( title_generation_template, query_generation_template, + image_prompt_generation_template, autocomplete_generation_template, tags_generation_template, emoji_generation_template, @@ -23,6 +24,7 @@ from open_webui.utils.task import get_task_model_id from open_webui.config import ( DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, + DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE, @@ -50,6 +52,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)): "TASK_MODEL": request.app.state.config.TASK_MODEL, "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, @@ -65,6 +68,7 @@ class TaskConfigForm(BaseModel): TASK_MODEL: Optional[str] TASK_MODEL_EXTERNAL: Optional[str] TITLE_GENERATION_PROMPT_TEMPLATE: str + IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str ENABLE_AUTOCOMPLETE_GENERATION: bool AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int TAGS_GENERATION_PROMPT_TEMPLATE: str @@ -114,6 +118,7 @@ async def update_task_config( "TASK_MODEL": request.app.state.config.TASK_MODEL, "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, @@ -256,6 +261,66 @@ async def generate_chat_tags( ) +@router.post("/image_prompt/completions") +async def generate_image_prompt( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating image prompt using model {task_model_id} for user {user.email} " + ) + + if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "": + template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE + + content = image_prompt_generation_template( + template, + form_data["messages"], + user={ + "name": user.name, + }, + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + "task": str(TASKS.IMAGE_PROMPT_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + log.error("Exception occurred", exc_info=True) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": "An internal error has occurred."}, + ) + + @router.post("/queries/completions") async def generate_queries( request: Request, form_data: dict, user=Depends(get_verified_user) diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 9e95ebe5a..7b9144b4c 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -70,7 +70,7 @@ async def create_new_tools( user=Depends(get_verified_user), ): if user.role != "admin" and not has_permission( - user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS + user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -165,7 +165,12 @@ async def update_tools_by_id( detail=ERROR_MESSAGES.NOT_FOUND, ) - if tools.user_id != user.id and user.role != "admin": + # Is the user the original creator, in a group with write access, or an admin + if ( + tools.user_id != user.id + and not has_access(user.id, "write", tools.access_control) + and user.role != "admin" + ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, @@ -304,6 +309,17 @@ async def update_tools_valves_by_id( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND, ) + + if ( + tools.user_id != user.id + and not has_access(user.id, "write", tools.access_control) + and user.role != "admin" + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + if id in request.app.state.TOOLS: tools_module = request.app.state.TOOLS[id] else: diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 7006091e1..b37ad4b39 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -62,27 +62,44 @@ async def get_user_permissisions(user=Depends(get_verified_user)): # User Default Permissions ############################ class WorkspacePermissions(BaseModel): - models: bool - knowledge: bool - prompts: bool - tools: bool + models: bool = False + knowledge: bool = False + prompts: bool = False + tools: bool = False class ChatPermissions(BaseModel): - file_upload: bool - delete: bool - edit: bool - temporary: bool + controls: bool = True + file_upload: bool = True + delete: bool = True + edit: bool = True + temporary: bool = True + + +class FeaturesPermissions(BaseModel): + web_search: bool = True + image_generation: bool = True class UserPermissions(BaseModel): workspace: WorkspacePermissions chat: ChatPermissions + features: FeaturesPermissions -@router.get("/default/permissions") +@router.get("/default/permissions", response_model=UserPermissions) async def get_user_permissions(request: Request, user=Depends(get_admin_user)): - return request.app.state.config.USER_PERMISSIONS + return { + "workspace": WorkspacePermissions( + **request.app.state.config.USER_PERMISSIONS.get("workspace", {}) + ), + "chat": ChatPermissions( + **request.app.state.config.USER_PERMISSIONS.get("chat", {}) + ), + "features": FeaturesPermissions( + **request.app.state.config.USER_PERMISSIONS.get("features", {}) + ), + } @router.post("/default/permissions") diff --git a/backend/open_webui/socket/utils.py b/backend/open_webui/socket/utils.py index d3de87b05..46fafbb9e 100644 --- a/backend/open_webui/socket/utils.py +++ b/backend/open_webui/socket/utils.py @@ -26,7 +26,7 @@ class RedisLock: def release_lock(self): lock_value = self.redis.get(self.lock_name) - if lock_value and lock_value.decode("utf-8") == self.lock_id: + if lock_value and lock_value == self.lock_id: self.redis.delete(self.lock_name) diff --git a/backend/open_webui/static/favicon.png b/backend/open_webui/static/favicon.png index 2b2074780..63735ad46 100644 Binary files a/backend/open_webui/static/favicon.png and b/backend/open_webui/static/favicon.png differ diff --git a/backend/open_webui/static/logo.png b/backend/open_webui/static/logo.png index 519af1db6..a652a5fb8 100644 Binary files a/backend/open_webui/static/logo.png and b/backend/open_webui/static/logo.png differ diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py index ae3347682..0c0a8aacf 100644 --- a/backend/open_webui/storage/provider.py +++ b/backend/open_webui/storage/provider.py @@ -1,121 +1,73 @@ import os +import shutil +import json +from abc import ABC, abstractmethod +from typing import BinaryIO, Tuple + import boto3 from botocore.exceptions import ClientError -import shutil - - -from typing import BinaryIO, Tuple, Optional, Union - -from open_webui.constants import ERROR_MESSAGES from open_webui.config import ( - STORAGE_PROVIDER, S3_ACCESS_KEY_ID, - S3_SECRET_ACCESS_KEY, S3_BUCKET_NAME, - S3_REGION_NAME, S3_ENDPOINT_URL, + S3_REGION_NAME, + S3_SECRET_ACCESS_KEY, + GCS_BUCKET_NAME, + GOOGLE_APPLICATION_CREDENTIALS_JSON, + STORAGE_PROVIDER, UPLOAD_DIR, ) +from google.cloud import storage +from google.cloud.exceptions import GoogleCloudError, NotFound +from open_webui.constants import ERROR_MESSAGES -import boto3 -from botocore.exceptions import ClientError -from typing import BinaryIO, Tuple, Optional +class StorageProvider(ABC): + @abstractmethod + def get_file(self, file_path: str) -> str: + pass + + @abstractmethod + def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]: + pass + + @abstractmethod + def delete_all_files(self) -> None: + pass + + @abstractmethod + def delete_file(self, file_path: str) -> None: + pass -class StorageProvider: - def __init__(self, provider: Optional[str] = None): - self.storage_provider: str = provider or STORAGE_PROVIDER - - self.s3_client = None - self.s3_bucket_name: Optional[str] = None - - if self.storage_provider == "s3": - self._initialize_s3() - - def _initialize_s3(self) -> None: - """Initializes the S3 client and bucket name if using S3 storage.""" - self.s3_client = boto3.client( - "s3", - region_name=S3_REGION_NAME, - endpoint_url=S3_ENDPOINT_URL, - aws_access_key_id=S3_ACCESS_KEY_ID, - aws_secret_access_key=S3_SECRET_ACCESS_KEY, - ) - self.bucket_name = S3_BUCKET_NAME - - def _upload_to_s3(self, file_path: str, filename: str) -> Tuple[bytes, str]: - """Handles uploading of the file to S3 storage.""" - if not self.s3_client: - raise RuntimeError("S3 Client is not initialized.") - - try: - self.s3_client.upload_file(file_path, self.bucket_name, filename) - return ( - open(file_path, "rb").read(), - "s3://" + self.bucket_name + "/" + filename, - ) - except ClientError as e: - raise RuntimeError(f"Error uploading file to S3: {e}") - - def _upload_to_local(self, contents: bytes, filename: str) -> Tuple[bytes, str]: - """Handles uploading of the file to local storage.""" +class LocalStorageProvider(StorageProvider): + @staticmethod + def upload_file(file: BinaryIO, filename: str) -> Tuple[bytes, str]: + contents = file.read() + if not contents: + raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) file_path = f"{UPLOAD_DIR}/{filename}" with open(file_path, "wb") as f: f.write(contents) return contents, file_path - def _get_file_from_s3(self, file_path: str) -> str: - """Handles downloading of the file from S3 storage.""" - if not self.s3_client: - raise RuntimeError("S3 Client is not initialized.") - - try: - bucket_name, key = file_path.split("//")[1].split("/") - local_file_path = f"{UPLOAD_DIR}/{key}" - self.s3_client.download_file(bucket_name, key, local_file_path) - return local_file_path - except ClientError as e: - raise RuntimeError(f"Error downloading file from S3: {e}") - - def _get_file_from_local(self, file_path: str) -> str: + @staticmethod + def get_file(file_path: str) -> str: """Handles downloading of the file from local storage.""" return file_path - def _delete_from_s3(self, filename: str) -> None: - """Handles deletion of the file from S3 storage.""" - if not self.s3_client: - raise RuntimeError("S3 Client is not initialized.") - - try: - self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename) - except ClientError as e: - raise RuntimeError(f"Error deleting file from S3: {e}") - - def _delete_from_local(self, filename: str) -> None: + @staticmethod + def delete_file(file_path: str) -> None: """Handles deletion of the file from local storage.""" + filename = file_path.split("/")[-1] file_path = f"{UPLOAD_DIR}/{filename}" if os.path.isfile(file_path): os.remove(file_path) else: print(f"File {file_path} not found in local storage.") - def _delete_all_from_s3(self) -> None: - """Handles deletion of all files from S3 storage.""" - if not self.s3_client: - raise RuntimeError("S3 Client is not initialized.") - - try: - response = self.s3_client.list_objects_v2(Bucket=self.bucket_name) - if "Contents" in response: - for content in response["Contents"]: - self.s3_client.delete_object( - Bucket=self.bucket_name, Key=content["Key"] - ) - except ClientError as e: - raise RuntimeError(f"Error deleting all files from S3: {e}") - - def _delete_all_from_local(self) -> None: + @staticmethod + def delete_all_files() -> None: """Handles deletion of all files from local storage.""" if os.path.exists(UPLOAD_DIR): for filename in os.listdir(UPLOAD_DIR): @@ -130,40 +82,141 @@ class StorageProvider: else: print(f"Directory {UPLOAD_DIR} not found in local storage.") - def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]: - """Uploads a file either to S3 or the local file system.""" - contents = file.read() - if not contents: - raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) - contents, file_path = self._upload_to_local(contents, filename) - if self.storage_provider == "s3": - return self._upload_to_s3(file_path, filename) - return contents, file_path +class S3StorageProvider(StorageProvider): + def __init__(self): + self.s3_client = boto3.client( + "s3", + region_name=S3_REGION_NAME, + endpoint_url=S3_ENDPOINT_URL, + aws_access_key_id=S3_ACCESS_KEY_ID, + aws_secret_access_key=S3_SECRET_ACCESS_KEY, + ) + self.bucket_name = S3_BUCKET_NAME + + def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]: + """Handles uploading of the file to S3 storage.""" + _, file_path = LocalStorageProvider.upload_file(file, filename) + try: + self.s3_client.upload_file(file_path, self.bucket_name, filename) + return ( + open(file_path, "rb").read(), + "s3://" + self.bucket_name + "/" + filename, + ) + except ClientError as e: + raise RuntimeError(f"Error uploading file to S3: {e}") def get_file(self, file_path: str) -> str: - """Downloads a file either from S3 or the local file system and returns the file path.""" - if self.storage_provider == "s3": - return self._get_file_from_s3(file_path) - return self._get_file_from_local(file_path) + """Handles downloading of the file from S3 storage.""" + try: + bucket_name, key = file_path.split("//")[1].split("/") + local_file_path = f"{UPLOAD_DIR}/{key}" + self.s3_client.download_file(bucket_name, key, local_file_path) + return local_file_path + except ClientError as e: + raise RuntimeError(f"Error downloading file from S3: {e}") def delete_file(self, file_path: str) -> None: - """Deletes a file either from S3 or the local file system.""" + """Handles deletion of the file from S3 storage.""" filename = file_path.split("/")[-1] - - if self.storage_provider == "s3": - self._delete_from_s3(filename) + try: + self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename) + except ClientError as e: + raise RuntimeError(f"Error deleting file from S3: {e}") # Always delete from local storage - self._delete_from_local(filename) + LocalStorageProvider.delete_file(file_path) def delete_all_files(self) -> None: - """Deletes all files from the storage.""" - if self.storage_provider == "s3": - self._delete_all_from_s3() + """Handles deletion of all files from S3 storage.""" + try: + response = self.s3_client.list_objects_v2(Bucket=self.bucket_name) + if "Contents" in response: + for content in response["Contents"]: + self.s3_client.delete_object( + Bucket=self.bucket_name, Key=content["Key"] + ) + except ClientError as e: + raise RuntimeError(f"Error deleting all files from S3: {e}") # Always delete from local storage - self._delete_all_from_local() + LocalStorageProvider.delete_all_files() -Storage = StorageProvider(provider=STORAGE_PROVIDER) +class GCSStorageProvider(StorageProvider): + def __init__(self): + self.bucket_name = GCS_BUCKET_NAME + + if GOOGLE_APPLICATION_CREDENTIALS_JSON: + self.gcs_client = storage.Client.from_service_account_info( + info=json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON) + ) + else: + # if no credentials json is provided, credentials will be picked up from the environment + # if running on local environment, credentials would be user credentials + # if running on a Compute Engine instance, credentials would be from Google Metadata server + self.gcs_client = storage.Client() + self.bucket = self.gcs_client.bucket(GCS_BUCKET_NAME) + + def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]: + """Handles uploading of the file to GCS storage.""" + contents, file_path = LocalStorageProvider.upload_file(file, filename) + try: + blob = self.bucket.blob(filename) + blob.upload_from_filename(file_path) + return contents, "gs://" + self.bucket_name + "/" + filename + except GoogleCloudError as e: + raise RuntimeError(f"Error uploading file to GCS: {e}") + + def get_file(self, file_path: str) -> str: + """Handles downloading of the file from GCS storage.""" + try: + filename = file_path.removeprefix("gs://").split("/")[1] + local_file_path = f"{UPLOAD_DIR}/{filename}" + blob = self.bucket.get_blob(filename) + blob.download_to_filename(local_file_path) + + return local_file_path + except NotFound as e: + raise RuntimeError(f"Error downloading file from GCS: {e}") + + def delete_file(self, file_path: str) -> None: + """Handles deletion of the file from GCS storage.""" + try: + filename = file_path.removeprefix("gs://").split("/")[1] + blob = self.bucket.get_blob(filename) + blob.delete() + except NotFound as e: + raise RuntimeError(f"Error deleting file from GCS: {e}") + + # Always delete from local storage + LocalStorageProvider.delete_file(file_path) + + def delete_all_files(self) -> None: + """Handles deletion of all files from GCS storage.""" + try: + blobs = self.bucket.list_blobs() + + for blob in blobs: + blob.delete() + + except NotFound as e: + raise RuntimeError(f"Error deleting all files from GCS: {e}") + + # Always delete from local storage + LocalStorageProvider.delete_all_files() + + +def get_storage_provider(storage_provider: str): + if storage_provider == "local": + Storage = LocalStorageProvider() + elif storage_provider == "s3": + Storage = S3StorageProvider() + elif storage_provider == "gcs": + Storage = GCSStorageProvider() + else: + raise RuntimeError(f"Unsupported storage provider: {storage_provider}") + return Storage + + +Storage = get_storage_provider(STORAGE_PROVIDER) diff --git a/backend/open_webui/test/apps/webui/storage/test_provider.py b/backend/open_webui/test/apps/webui/storage/test_provider.py new file mode 100644 index 000000000..863106e75 --- /dev/null +++ b/backend/open_webui/test/apps/webui/storage/test_provider.py @@ -0,0 +1,274 @@ +import io +import os +import boto3 +import pytest +from botocore.exceptions import ClientError +from moto import mock_aws +from open_webui.storage import provider +from gcp_storage_emulator.server import create_server +from google.cloud import storage + + +def mock_upload_dir(monkeypatch, tmp_path): + """Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory.""" + directory = tmp_path / "uploads" + directory.mkdir() + monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory)) + return directory + + +def test_imports(): + provider.StorageProvider + provider.LocalStorageProvider + provider.S3StorageProvider + provider.GCSStorageProvider + provider.Storage + + +def test_get_storage_provider(): + Storage = provider.get_storage_provider("local") + assert isinstance(Storage, provider.LocalStorageProvider) + Storage = provider.get_storage_provider("s3") + assert isinstance(Storage, provider.S3StorageProvider) + Storage = provider.get_storage_provider("gcs") + assert isinstance(Storage, provider.GCSStorageProvider) + with pytest.raises(RuntimeError): + provider.get_storage_provider("invalid") + + +def test_class_instantiation(): + with pytest.raises(TypeError): + provider.StorageProvider() + with pytest.raises(TypeError): + + class Test(provider.StorageProvider): + pass + + Test() + provider.LocalStorageProvider() + provider.S3StorageProvider() + provider.GCSStorageProvider() + + +class TestLocalStorageProvider: + Storage = provider.LocalStorageProvider() + file_content = b"test content" + file_bytesio = io.BytesIO(file_content) + filename = "test.txt" + filename_extra = "test_exyta.txt" + file_bytesio_empty = io.BytesIO() + + def test_upload_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + contents, file_path = self.Storage.upload_file(self.file_bytesio, self.filename) + assert (upload_dir / self.filename).exists() + assert (upload_dir / self.filename).read_bytes() == self.file_content + assert contents == self.file_content + assert file_path == str(upload_dir / self.filename) + with pytest.raises(ValueError): + self.Storage.upload_file(self.file_bytesio_empty, self.filename) + + def test_get_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + file_path = str(upload_dir / self.filename) + file_path_return = self.Storage.get_file(file_path) + assert file_path == file_path_return + + def test_delete_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + (upload_dir / self.filename).write_bytes(self.file_content) + assert (upload_dir / self.filename).exists() + file_path = str(upload_dir / self.filename) + self.Storage.delete_file(file_path) + assert not (upload_dir / self.filename).exists() + + def test_delete_all_files(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + (upload_dir / self.filename).write_bytes(self.file_content) + (upload_dir / self.filename_extra).write_bytes(self.file_content) + self.Storage.delete_all_files() + assert not (upload_dir / self.filename).exists() + assert not (upload_dir / self.filename_extra).exists() + + +@mock_aws +class TestS3StorageProvider: + + def __init__(self): + self.Storage = provider.S3StorageProvider() + self.Storage.bucket_name = "my-bucket" + self.s3_client = boto3.resource("s3", region_name="us-east-1") + self.file_content = b"test content" + self.filename = "test.txt" + self.filename_extra = "test_exyta.txt" + self.file_bytesio_empty = io.BytesIO() + super().__init__() + + def test_upload_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + # S3 checks + with pytest.raises(Exception): + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) + contents, s3_file_path = self.Storage.upload_file( + io.BytesIO(self.file_content), self.filename + ) + object = self.s3_client.Object(self.Storage.bucket_name, self.filename) + assert self.file_content == object.get()["Body"].read() + # local checks + assert (upload_dir / self.filename).exists() + assert (upload_dir / self.filename).read_bytes() == self.file_content + assert contents == self.file_content + assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename + with pytest.raises(ValueError): + self.Storage.upload_file(self.file_bytesio_empty, self.filename) + + def test_get_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) + contents, s3_file_path = self.Storage.upload_file( + io.BytesIO(self.file_content), self.filename + ) + file_path = self.Storage.get_file(s3_file_path) + assert file_path == str(upload_dir / self.filename) + assert (upload_dir / self.filename).exists() + + def test_delete_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) + contents, s3_file_path = self.Storage.upload_file( + io.BytesIO(self.file_content), self.filename + ) + assert (upload_dir / self.filename).exists() + self.Storage.delete_file(s3_file_path) + assert not (upload_dir / self.filename).exists() + with pytest.raises(ClientError) as exc: + self.s3_client.Object(self.Storage.bucket_name, self.filename).load() + error = exc.value.response["Error"] + assert error["Code"] == "404" + assert error["Message"] == "Not Found" + + def test_delete_all_files(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + # create 2 files + self.s3_client.create_bucket(Bucket=self.Storage.bucket_name) + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + object = self.s3_client.Object(self.Storage.bucket_name, self.filename) + assert self.file_content == object.get()["Body"].read() + assert (upload_dir / self.filename).exists() + assert (upload_dir / self.filename).read_bytes() == self.file_content + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra) + object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra) + assert self.file_content == object.get()["Body"].read() + assert (upload_dir / self.filename).exists() + assert (upload_dir / self.filename).read_bytes() == self.file_content + + self.Storage.delete_all_files() + assert not (upload_dir / self.filename).exists() + with pytest.raises(ClientError) as exc: + self.s3_client.Object(self.Storage.bucket_name, self.filename).load() + error = exc.value.response["Error"] + assert error["Code"] == "404" + assert error["Message"] == "Not Found" + assert not (upload_dir / self.filename_extra).exists() + with pytest.raises(ClientError) as exc: + self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load() + error = exc.value.response["Error"] + assert error["Code"] == "404" + assert error["Message"] == "Not Found" + + self.Storage.delete_all_files() + assert not (upload_dir / self.filename).exists() + assert not (upload_dir / self.filename_extra).exists() + + +class TestGCSStorageProvider: + Storage = provider.GCSStorageProvider() + Storage.bucket_name = "my-bucket" + file_content = b"test content" + filename = "test.txt" + filename_extra = "test_exyta.txt" + file_bytesio_empty = io.BytesIO() + + @pytest.fixture(scope="class") + def setup(self): + host, port = "localhost", 9023 + + server = create_server(host, port, in_memory=True) + server.start() + os.environ["STORAGE_EMULATOR_HOST"] = f"http://{host}:{port}" + + gcs_client = storage.Client() + bucket = gcs_client.bucket(self.Storage.bucket_name) + bucket.create() + self.Storage.gcs_client, self.Storage.bucket = gcs_client, bucket + yield + bucket.delete(force=True) + server.stop() + + def test_upload_file(self, monkeypatch, tmp_path, setup): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + # catch error if bucket does not exist + with pytest.raises(Exception): + self.Storage.bucket = monkeypatch(self.Storage, "bucket", None) + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + contents, gcs_file_path = self.Storage.upload_file( + io.BytesIO(self.file_content), self.filename + ) + object = self.Storage.bucket.get_blob(self.filename) + assert self.file_content == object.download_as_bytes() + # local checks + assert (upload_dir / self.filename).exists() + assert (upload_dir / self.filename).read_bytes() == self.file_content + assert contents == self.file_content + assert gcs_file_path == "gs://" + self.Storage.bucket_name + "/" + self.filename + # test error if file is empty + with pytest.raises(ValueError): + self.Storage.upload_file(self.file_bytesio_empty, self.filename) + + def test_get_file(self, monkeypatch, tmp_path, setup): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + contents, gcs_file_path = self.Storage.upload_file( + io.BytesIO(self.file_content), self.filename + ) + file_path = self.Storage.get_file(gcs_file_path) + assert file_path == str(upload_dir / self.filename) + assert (upload_dir / self.filename).exists() + + def test_delete_file(self, monkeypatch, tmp_path, setup): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + contents, gcs_file_path = self.Storage.upload_file( + io.BytesIO(self.file_content), self.filename + ) + # ensure that local directory has the uploaded file as well + assert (upload_dir / self.filename).exists() + assert self.Storage.bucket.get_blob(self.filename).name == self.filename + self.Storage.delete_file(gcs_file_path) + # check that deleting file from gcs will delete the local file as well + assert not (upload_dir / self.filename).exists() + assert self.Storage.bucket.get_blob(self.filename) == None + + def test_delete_all_files(self, monkeypatch, tmp_path, setup): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + # create 2 files + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + object = self.Storage.bucket.get_blob(self.filename) + assert (upload_dir / self.filename).exists() + assert (upload_dir / self.filename).read_bytes() == self.file_content + assert self.Storage.bucket.get_blob(self.filename).name == self.filename + assert self.file_content == object.download_as_bytes() + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra) + object = self.Storage.bucket.get_blob(self.filename_extra) + assert (upload_dir / self.filename_extra).exists() + assert (upload_dir / self.filename_extra).read_bytes() == self.file_content + assert ( + self.Storage.bucket.get_blob(self.filename_extra).name + == self.filename_extra + ) + assert self.file_content == object.download_as_bytes() + + self.Storage.delete_all_files() + assert not (upload_dir / self.filename).exists() + assert not (upload_dir / self.filename_extra).exists() + assert self.Storage.bucket.get_blob(self.filename) == None + assert self.Storage.bucket.get_blob(self.filename_extra) == None diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py index da61e7fb3..1699cfaa7 100644 --- a/backend/open_webui/utils/access_control.py +++ b/backend/open_webui/utils/access_control.py @@ -1,9 +1,30 @@ from typing import Optional, Union, List, Dict, Any from open_webui.models.users import Users, UserModel from open_webui.models.groups import Groups + + +from open_webui.config import DEFAULT_USER_PERMISSIONS import json +def fill_missing_permissions( + permissions: Dict[str, Any], default_permissions: Dict[str, Any] +) -> Dict[str, Any]: + """ + Recursively fills in missing properties in the permissions dictionary + using the default permissions as a template. + """ + for key, value in default_permissions.items(): + if key not in permissions: + permissions[key] = value + elif isinstance(value, dict) and isinstance( + permissions[key], dict + ): # Both are nested dictionaries + permissions[key] = fill_missing_permissions(permissions[key], value) + + return permissions + + def get_permissions( user_id: str, default_permissions: Dict[str, Any], @@ -27,39 +48,45 @@ def get_permissions( if key not in permissions: permissions[key] = value else: - permissions[key] = permissions[key] or value + permissions[key] = ( + permissions[key] or value + ) # Use the most permissive value (True > False) return permissions user_groups = Groups.get_groups_by_member_id(user_id) - # deep copy default permissions to avoid modifying the original dict + # Deep copy default permissions to avoid modifying the original dict permissions = json.loads(json.dumps(default_permissions)) + # Combine permissions from all user groups for group in user_groups: group_permissions = group.permissions permissions = combine_permissions(permissions, group_permissions) + # Ensure all fields from default_permissions are present and filled in + permissions = fill_missing_permissions(permissions, default_permissions) + return permissions def has_permission( user_id: str, permission_key: str, - default_permissions: Dict[str, bool] = {}, + default_permissions: Dict[str, Any] = {}, ) -> bool: """ Check if a user has a specific permission by checking the group permissions - and falls back to default permissions if not found in any group. + and fall back to default permissions if not found in any group. Permission keys can be hierarchical and separated by dots ('.'). """ - def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool: + def get_permission(permissions: Dict[str, Any], keys: List[str]) -> bool: """Traverse permissions dict using a list of keys (from dot-split permission_key).""" for key in keys: if key not in permissions: return False # If any part of the hierarchy is missing, deny access - permissions = permissions[key] # Go one level deeper + permissions = permissions[key] # Traverse one level deeper return bool(permissions) # Return the boolean at the final level @@ -73,7 +100,10 @@ def has_permission( if get_permission(group_permissions, permission_hierarchy): return True - # Check default permissions afterwards if the group permissions don't allow it + # Check default permissions afterward if the group permissions don't allow it + default_permissions = fill_missing_permissions( + default_permissions, DEFAULT_USER_PERMISSIONS + ) return get_permission(default_permissions, permission_hierarchy) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 3c53435a4..6b2329be1 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -28,9 +28,13 @@ from open_webui.socket.main import ( from open_webui.routers.tasks import ( generate_queries, generate_title, + generate_image_prompt, generate_chat_tags, ) from open_webui.routers.retrieval import process_web_search, SearchForm +from open_webui.routers.images import image_generations, GenerateImageForm + + from open_webui.utils.webhook import post_webhook @@ -486,6 +490,100 @@ async def chat_web_search_handler( return form_data +async def chat_image_generation_handler( + request: Request, form_data: dict, extra_params: dict, user +): + __event_emitter__ = extra_params["__event_emitter__"] + await __event_emitter__( + { + "type": "status", + "data": {"description": "Generating an image", "done": False}, + } + ) + + messages = form_data["messages"] + user_message = get_last_user_message(messages) + + prompt = user_message + negative_prompt = "" + + if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION: + try: + res = await generate_image_prompt( + request, + { + "model": form_data["model"], + "messages": messages, + }, + user, + ) + + response = res["choices"][0]["message"]["content"] + + try: + bracket_start = response.find("{") + bracket_end = response.rfind("}") + 1 + + if bracket_start == -1 or bracket_end == -1: + raise Exception("No JSON object found in the response") + + response = response[bracket_start:bracket_end] + response = json.loads(response) + prompt = response.get("prompt", []) + except Exception as e: + prompt = user_message + + except Exception as e: + log.exception(e) + prompt = user_message + + system_message_content = "" + + try: + images = await image_generations( + request=request, + form_data=GenerateImageForm(**{"prompt": prompt}), + user=user, + ) + + await __event_emitter__( + { + "type": "status", + "data": {"description": "Generated an image", "done": True}, + } + ) + + for image in images: + await __event_emitter__( + { + "type": "message", + "data": {"content": f"![Generated Image]({image['url']})\n"}, + } + ) + + system_message_content = "User is shown the generated image, tell the user that the image has been generated" + except Exception as e: + log.exception(e) + await __event_emitter__( + { + "type": "status", + "data": { + "description": f"An error occured while generating an image", + "done": True, + }, + } + ) + + system_message_content = "Unable to generate an image, tell the user that an error occured" + + if system_message_content: + form_data["messages"] = add_or_update_system_message( + system_message_content, form_data["messages"] + ) + + return form_data + + async def chat_completion_files_handler( request: Request, body: dict, user: UserModel ) -> tuple[dict, dict[str, list]]: @@ -523,17 +621,28 @@ async def chat_completion_files_handler( if len(queries) == 0: queries = [get_last_user_message(body["messages"])] - sources = get_sources_from_files( - files=files, - queries=queries, - embedding_function=request.app.state.EMBEDDING_FUNCTION, - k=request.app.state.config.TOP_K, - reranking_function=request.app.state.rf, - r=request.app.state.config.RELEVANCE_THRESHOLD, - hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, - ) + try: + # Offload get_sources_from_files to a separate thread + loop = asyncio.get_running_loop() + with ThreadPoolExecutor() as executor: + sources = await loop.run_in_executor( + executor, + lambda: get_sources_from_files( + files=files, + queries=queries, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=request.app.state.config.TOP_K, + reranking_function=request.app.state.rf, + r=request.app.state.config.RELEVANCE_THRESHOLD, + hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, + ), + ) + + except Exception as e: + log.exception(e) log.debug(f"rag_contexts:sources: {sources}") + return body, {"sources": sources} @@ -562,6 +671,10 @@ def apply_params_to_form_data(form_data, model): if "frequency_penalty" in params: form_data["frequency_penalty"] = params["frequency_penalty"] + + if "reasoning_effort" in params: + form_data["reasoning_effort"] = params["reasoning_effort"] + return form_data @@ -640,6 +753,11 @@ async def process_chat_payload(request, form_data, metadata, user, model): request, form_data, extra_params, user ) + if "image_generation" in features and features["image_generation"]: + form_data = await chat_image_generation_handler( + request, form_data, extra_params, user + ) + try: form_data, flags = await chat_completion_filter_functions_handler( request, form_data, model, extra_params @@ -770,14 +888,17 @@ async def process_chat_response( ) if res and isinstance(res, dict): - title = ( - res.get("choices", [])[0] - .get("message", {}) - .get( - "content", - message.get("content", "New Chat"), - ) - ).strip() + if len(res.get("choices", [])) == 1: + title = ( + res.get("choices", [])[0] + .get("message", {}) + .get( + "content", + message.get("content", "New Chat"), + ) + ).strip() + else: + title = None if not title: title = messages[0].get("content", "New Chat") @@ -814,11 +935,14 @@ async def process_chat_response( ) if res and isinstance(res, dict): - tags_string = ( - res.get("choices", [])[0] - .get("message", {}) - .get("content", "") - ) + if len(res.get("choices", [])) == 1: + tags_string = ( + res.get("choices", [])[0] + .get("message", {}) + .get("content", "") + ) + else: + tags_string = "" tags_string = tags_string[ tags_string.find("{") : tags_string.rfind("}") + 1 @@ -837,7 +961,7 @@ async def process_chat_response( } ) except Exception as e: - print(f"Error: {e}") + pass event_emitter = None if ( @@ -952,6 +1076,16 @@ async def process_chat_response( }, ) + # We might want to disable this by default + detect_reasoning = True + reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"] + current_tag = None + + reasoning_start_time = None + + reasoning_content = "" + ongoing_content = "" + async for line in response.body_iterator: line = line.decode("utf-8") if isinstance(line, bytes) else line data = line @@ -960,12 +1094,12 @@ async def process_chat_response( if not data.strip(): continue - # "data: " is the prefix for each event - if not data.startswith("data: "): + # "data:" is the prefix for each event + if not data.startswith("data:"): continue # Remove the prefix - data = data[len("data: ") :] + data = data[len("data:") :].strip() try: data = json.loads(data) @@ -978,7 +1112,6 @@ async def process_chat_response( "selectedModelId": data["selected_model_id"], }, ) - else: value = ( data.get("choices", [])[0] @@ -989,6 +1122,73 @@ async def process_chat_response( if value: content = f"{content}{value}" + if detect_reasoning: + for tag in reasoning_tags: + start_tag = f"<{tag}>\n" + end_tag = f"\n" + + if start_tag in content: + # Remove the start tag + content = content.replace(start_tag, "") + ongoing_content = content + + reasoning_start_time = time.time() + reasoning_content = "" + + current_tag = tag + break + + if reasoning_start_time is not None: + # Remove the last value from the content + content = content[: -len(value)] + + reasoning_content += value + + end_tag = f"\n" + if end_tag in reasoning_content: + reasoning_end_time = time.time() + reasoning_duration = int( + reasoning_end_time + - reasoning_start_time + ) + reasoning_content = ( + reasoning_content.strip( + f"<{current_tag}>\n" + ) + .strip(end_tag) + .strip() + ) + + if reasoning_content: + reasoning_display_content = "\n".join( + ( + f"> {line}" + if not line.startswith(">") + else line + ) + for line in reasoning_content.splitlines() + ) + + # Format reasoning with
tag + content = f'{ongoing_content}
\nThought for {reasoning_duration} seconds\n{reasoning_display_content}\n
\n' + else: + content = "" + + reasoning_start_time = None + else: + + reasoning_display_content = "\n".join( + ( + f"> {line}" + if not line.startswith(">") + else line + ) + for line in reasoning_content.splitlines() + ) + + # Show ongoing thought process + content = f'{ongoing_content}
\nThinking…\n{reasoning_display_content}\n
\n' + if ENABLE_REALTIME_CHAT_SAVE: # Save message in the database Chats.upsert_message_to_chat_by_id_and_message_id( @@ -1009,10 +1209,8 @@ async def process_chat_response( "data": data, } ) - except Exception as e: done = "data: [DONE]" in line - if done: pass else: diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index c76a1453b..1ae6d4aa7 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -34,7 +34,7 @@ from open_webui.config import ( JWT_EXPIRES_IN, AppConfig, ) -from open_webui.constants import ERROR_MESSAGES +from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE from open_webui.utils.misc import parse_duration from open_webui.utils.auth import get_password_hash, create_token @@ -63,17 +63,8 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN class OAuthManager: def __init__(self): self.oauth = OAuth() - for provider_name, provider_config in OAUTH_PROVIDERS.items(): - self.oauth.register( - name=provider_name, - client_id=provider_config["client_id"], - client_secret=provider_config["client_secret"], - server_metadata_url=provider_config["server_metadata_url"], - client_kwargs={ - "scope": provider_config["scope"], - }, - redirect_uri=provider_config["redirect_uri"], - ) + for _, provider_config in OAUTH_PROVIDERS.items(): + provider_config["register"](self.oauth) def get_client(self, provider_name): return self.oauth.create_client(provider_name) @@ -200,14 +191,14 @@ class OAuthManager: except Exception as e: log.warning(f"OAuth callback error: {e}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - user_data: UserInfo = token["userinfo"] + user_data: UserInfo = token.get("userinfo") if not user_data: user_data: UserInfo = await client.userinfo(token=token) if not user_data: log.warning(f"OAuth callback failed, user data is missing: {token}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - sub = user_data.get("sub") + sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub")) if not sub: log.warning(f"OAuth callback failed, sub is missing: {user_data}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) @@ -255,12 +246,20 @@ class OAuthManager: raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM - picture_url = user_data.get(picture_claim, "") + picture_url = user_data.get( + picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "") + ) if picture_url: # Download the profile image into a base64 string try: + access_token = token.get("access_token") + get_kwargs = {} + if access_token: + get_kwargs["headers"] = { + "Authorization": f"Bearer {access_token}", + } async with aiohttp.ClientSession() as session: - async with session.get(picture_url) as resp: + async with session.get(picture_url, **get_kwargs) as resp: picture = await resp.read() base64_encoded_picture = base64.b64encode( picture @@ -295,12 +294,10 @@ class OAuthManager: if auth_manager_config.WEBHOOK_URL: post_webhook( auth_manager_config.WEBHOOK_URL, - auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + WEBHOOK_MESSAGES.USER_SIGNUP(user.name), { "action": "signup", - "message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP( - user.name - ), + "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), "user": user.model_dump_json(exclude_none=True), }, ) @@ -314,7 +311,7 @@ class OAuthManager: expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN), ) - if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT: + if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT and user.role != "admin": self.update_user_groups( user=user, user_data=user_data, diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index fdc62f79f..13f98ee01 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -47,6 +47,7 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: "top_p": float, "max_tokens": int, "frequency_penalty": float, + "reasoning_effort": str, "seed": lambda x: x, "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], } diff --git a/backend/open_webui/utils/pdf_generator.py b/backend/open_webui/utils/pdf_generator.py index bbaf42dbb..1bb9f76b3 100644 --- a/backend/open_webui/utils/pdf_generator.py +++ b/backend/open_webui/utils/pdf_generator.py @@ -53,6 +53,7 @@ class PDFGenerator: # - https://facelessuser.github.io/pymdown-extensions/usage_notes/ # html_content = markdown(content, extensions=["pymdownx.extra"]) + content = content.replace("\n", "
") html_message = f"""
diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index ebb7483ba..f5ba75ebe 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -217,6 +217,24 @@ def tags_generation_template( return template +def image_prompt_generation_template( + template: str, messages: list[dict], user: Optional[dict] = None +) -> str: + prompt = get_last_user_message(messages) + template = replace_prompt_variable(template, prompt) + template = replace_messages_variable(template, messages) + + template = prompt_template( + template, + **( + {"user_name": user.get("name"), "user_location": user.get("location")} + if user + else {} + ), + ) + return template + + def emoji_generation_template( template: str, prompt: str, user: Optional[dict] = None ) -> str: diff --git a/backend/requirements.txt b/backend/requirements.txt index f951d78db..eecb9c4a5 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -40,14 +40,15 @@ tiktoken langchain==0.3.7 langchain-community==0.3.7 -langchain-chroma==0.1.4 fake-useragent==1.5.1 -chromadb==0.5.15 +chromadb==0.6.2 pymilvus==2.5.0 qdrant-client~=1.12.0 opensearch-py==2.7.1 + +transformers sentence-transformers==3.3.1 colbert-ai==0.2.21 einops==0.8.0 @@ -88,7 +89,7 @@ pytube==15.0.0 extract_msg pydub -duckduckgo-search~=6.3.5 +duckduckgo-search~=7.2.1 ## Google Drive google-api-python-client @@ -101,6 +102,7 @@ pytest~=8.3.2 pytest-docker~=3.1.1 googleapis-common-protos==1.63.2 +google-cloud-storage==2.19.0 ## LDAP ldap3==2.9.1 diff --git a/docs/apache.md b/docs/apache.md index ebbcc17f4..1bd920593 100644 --- a/docs/apache.md +++ b/docs/apache.md @@ -16,6 +16,9 @@ For the UI configuration, you can set up the Apache VirtualHost as follows: ProxyPass / http://server.com:3000/ nocanon ProxyPassReverse / http://server.com:3000/ + # Needed after 0.5 + ProxyPass / ws://server.com:3000/ nocanon + ProxyPassReverse / ws://server.com:3000/ ``` @@ -32,6 +35,9 @@ Enable the site first before you can request SSL: ProxyPass / http://server.com:3000/ nocanon ProxyPassReverse / http://server.com:3000/ + # Needed after 0.5 + ProxyPass / ws://server.com:3000/ nocanon + ProxyPassReverse / ws://server.com:3000/ SSLEngine on SSLCertificateFile /etc/ssl/virtualmin/170514456861234/ssl.cert diff --git a/package-lock.json b/package-lock.json index 0d78397ef..c98e814d9 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.5.4", + "version": "0.5.7", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.5.4", + "version": "0.5.7", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", @@ -41,7 +41,7 @@ "i18next-resources-to-backend": "^1.2.0", "idb": "^7.1.1", "js-sha256": "^0.10.1", - "katex": "^0.16.9", + "katex": "^0.16.21", "marked": "^9.1.0", "mermaid": "^10.9.3", "paneforge": "^0.0.6", @@ -89,7 +89,7 @@ "tailwindcss": "^3.3.3", "tslib": "^2.4.1", "typescript": "^5.5.4", - "vite": "^5.3.5", + "vite": "^5.4.14", "vitest": "^1.6.0" }, "engines": { @@ -7110,13 +7110,14 @@ } }, "node_modules/katex": { - "version": "0.16.10", - "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.10.tgz", - "integrity": "sha512-ZiqaC04tp2O5utMsl2TEZTXxa6WSC4yo0fv5ML++D3QZv/vx2Mct0mTlRx3O+uUkjfuAgOkzsCmq5MiUEsDDdA==", + "version": "0.16.21", + "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.21.tgz", + "integrity": "sha512-XvqR7FgOHtWupfMiigNzmh+MgUVmDGU2kXZm899ZkPfcuoPuFxyHmXsgATDpFZDAXCI8tvinaVcDo8PIIJSo4A==", "funding": [ "https://opencollective.com/katex", "https://github.com/sponsors/katex" ], + "license": "MIT", "dependencies": { "commander": "^8.3.0" }, @@ -11677,9 +11678,10 @@ } }, "node_modules/vite": { - "version": "5.4.6", - "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.6.tgz", - "integrity": "sha512-IeL5f8OO5nylsgzd9tq4qD2QqI0k2CQLGrWD0rCN0EQJZpBK5vJAx0I+GDkMOXxQX/OfFHMuLIx6ddAxGX/k+Q==", + "version": "5.4.14", + "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.14.tgz", + "integrity": "sha512-EK5cY7Q1D8JNhSaPKVK4pwBFvaTmZxEnoKXLG/U9gmdDcihQGNzFlgIvaxezFR4glP1LsuiedwMBqCXH3wZccA==", + "license": "MIT", "dependencies": { "esbuild": "^0.21.3", "postcss": "^8.4.43", diff --git a/package.json b/package.json index 6a0f451fc..a2463d9e3 100644 --- a/package.json +++ b/package.json @@ -1,9 +1,10 @@ { "name": "open-webui", - "version": "0.5.4", + "version": "0.5.7", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", + "dev:5050": "npm run pyodide:fetch && vite dev --port 5050", "build": "npm run pyodide:fetch && vite build", "preview": "vite preview", "check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json", @@ -44,7 +45,7 @@ "tailwindcss": "^3.3.3", "tslib": "^2.4.1", "typescript": "^5.5.4", - "vite": "^5.3.5", + "vite": "^5.4.14", "vitest": "^1.6.0" }, "type": "module", @@ -82,7 +83,7 @@ "i18next-resources-to-backend": "^1.2.0", "idb": "^7.1.1", "js-sha256": "^0.10.1", - "katex": "^0.16.9", + "katex": "^0.16.21", "marked": "^9.1.0", "mermaid": "^10.9.3", "paneforge": "^0.0.6", diff --git a/pyproject.toml b/pyproject.toml index 63a97e69a..edd01db8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,14 +47,14 @@ dependencies = [ "langchain==0.3.7", "langchain-community==0.3.7", - "langchain-chroma==0.1.4", "fake-useragent==1.5.1", - "chromadb==0.5.15", + "chromadb==0.6.2", "pymilvus==2.5.0", "qdrant-client~=1.12.0", "opensearch-py==2.7.1", + "transformers", "sentence-transformers==3.3.1", "colbert-ai==0.2.21", "einops==0.8.0", @@ -94,15 +94,22 @@ dependencies = [ "extract_msg", "pydub", - "duckduckgo-search~=6.3.5", + "duckduckgo-search~=7.2.1", + + "google-api-python-client", + "google-auth-httplib2", + "google-auth-oauthlib", "docker~=7.1.0", "pytest~=8.3.2", "pytest-docker~=3.1.1", + "moto[s3]>=5.0.26", "googleapis-common-protos==1.63.2", + "google-cloud-storage==2.19.0", - "ldap3==2.9.1" + "ldap3==2.9.1", + "gcp-storage-emulator>=2024.8.3", ] readme = "README.md" requires-python = ">= 3.11, < 3.13.0a1" diff --git a/src/app.css b/src/app.css index fcc438bea..dadfda78f 100644 --- a/src/app.css +++ b/src/app.css @@ -53,11 +53,11 @@ math { } .markdown-prose { - @apply prose dark:prose-invert prose-headings:font-semibold prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; + @apply prose dark:prose-invert prose-blockquote:border-gray-100 prose-blockquote:dark:border-gray-800 prose-blockquote:border-l-2 prose-blockquote:not-italic prose-blockquote:font-normal prose-headings:font-semibold prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; } .markdown-prose-xs { - @apply text-xs prose dark:prose-invert prose-headings:font-semibold prose-hr:my-0 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; + @apply text-xs prose dark:prose-invert prose-blockquote:border-gray-100 prose-blockquote:dark:border-gray-800 prose-blockquote:border-l-2 prose-blockquote:not-italic prose-blockquote:font-normal prose-headings:font-semibold prose-hr:my-0 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; } .markdown a { @@ -68,6 +68,19 @@ math { font-family: 'Archivo', sans-serif; } +.drag-region { + -webkit-app-region: drag; +} + +.drag-region a, +.drag-region button { + -webkit-app-region: no-drag; +} + +.no-drag-region { + -webkit-app-region: no-drag; +} + iframe { @apply rounded-lg; } @@ -102,18 +115,62 @@ li p { select { background-image: url("data:image/svg+xml;charset=utf-8,%3Csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3E%3Cpath stroke='%236B7280' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='m6 8 4 4 4-4'/%3E%3C/svg%3E"); - background-position: right 0.5rem center; + background-position: right 0rem center; background-repeat: no-repeat; background-size: 1.5em 1.5em; - padding-right: 2.5rem; -webkit-print-color-adjust: exact; print-color-adjust: exact; + /* padding-right: 2.5rem; */ /* for Firefox */ -moz-appearance: none; /* for Chrome */ -webkit-appearance: none; } +@keyframes shimmer { + 0% { + background-position: 200% 0; + } + 100% { + background-position: -200% 0; + } +} + +.shimmer { + background: linear-gradient(90deg, #9a9b9e 25%, #2a2929 50%, #9a9b9e 75%); + background-size: 200% 100%; + background-clip: text; + -webkit-background-clip: text; + -webkit-text-fill-color: transparent; + animation: shimmer 4s linear infinite; + color: #818286; /* Fallback color */ +} + +:global(.dark) .shimmer { + background: linear-gradient(90deg, #818286 25%, #eae5e5 50%, #818286 75%); + background-size: 200% 100%; + background-clip: text; + -webkit-background-clip: text; + -webkit-text-fill-color: transparent; + animation: shimmer 4s linear infinite; + color: #a1a3a7; /* Darker fallback color for dark mode */ +} + +@keyframes smoothFadeIn { + 0% { + opacity: 0; + transform: translateY(-10px); + } + 100% { + opacity: 1; + transform: translateY(0); + } +} + +.status-description { + animation: smoothFadeIn 0.2s forwards; +} + .katex-mathml { display: none; } diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index 16eed9f21..b96567e63 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -360,12 +360,7 @@ export const generateChatCompletion = async (token: string = '', body: object) = return [res, controller]; }; -export const createModel = async ( - token: string, - tagName: string, - content: string, - urlIdx: string | null = null -) => { +export const createModel = async (token: string, payload: object, urlIdx: string | null = null) => { let error = null; const res = await fetch( @@ -377,10 +372,7 @@ export const createModel = async ( 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, - body: JSON.stringify({ - name: tagName, - modelfile: content - }) + body: JSON.stringify(payload) } ).catch((err) => { error = err; diff --git a/src/lib/components/admin/Functions.svelte b/src/lib/components/admin/Functions.svelte index 03da04ea7..80c7d11cd 100644 --- a/src/lib/components/admin/Functions.svelte +++ b/src/lib/components/admin/Functions.svelte @@ -61,7 +61,7 @@ const shareHandler = async (func) => { const item = await getFunctionById(localStorage.token, func.id).catch((error) => { - toast.error(error); + toast.error(`${error}`); return null; }); @@ -88,7 +88,7 @@ const cloneHandler = async (func) => { const _function = await getFunctionById(localStorage.token, func.id).catch((error) => { - toast.error(error); + toast.error(`${error}`); return null; }); @@ -104,7 +104,7 @@ const exportHandler = async (func) => { const _function = await getFunctionById(localStorage.token, func.id).catch((error) => { - toast.error(error); + toast.error(`${error}`); return null; }); @@ -118,7 +118,7 @@ const deleteHandler = async (func) => { const res = await deleteFunctionById(localStorage.token, func.id).catch((error) => { - toast.error(error); + toast.error(`${error}`); return null; }); @@ -132,7 +132,7 @@ const toggleGlobalHandler = async (func) => { const res = await toggleGlobalById(localStorage.token, func.id).catch((error) => { - toast.error(error); + toast.error(`${error}`); }); if (res) { @@ -418,7 +418,7 @@ class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition" on:click={async () => { const _functions = await exportFunctions(localStorage.token).catch((error) => { - toast.error(error); + toast.error(`${error}`); return null; }); @@ -510,7 +510,7 @@ for (const func of _functions) { const res = await createNewFunction(localStorage.token, func).catch((error) => { - toast.error(error); + toast.error(`${error}`); return null; }); } diff --git a/src/lib/components/admin/Settings/Connections.svelte b/src/lib/components/admin/Settings/Connections.svelte index ddc19bb8f..35e6e0293 100644 --- a/src/lib/components/admin/Settings/Connections.svelte +++ b/src/lib/components/admin/Settings/Connections.svelte @@ -43,9 +43,8 @@ const updateOpenAIHandler = async () => { if (ENABLE_OPENAI_API !== null) { - OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.filter( - (url, urlIdx) => OPENAI_API_BASE_URLS.indexOf(url) === urlIdx && url !== '' - ).map((url) => url.replace(/\/$/, '')); + // Remove trailing slashes + OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.map((url) => url.replace(/\/$/, '')); // Check if API KEYS length is same than API URLS length if (OPENAI_API_KEYS.length !== OPENAI_API_BASE_URLS.length) { @@ -69,7 +68,7 @@ OPENAI_API_KEYS: OPENAI_API_KEYS, OPENAI_API_CONFIGS: OPENAI_API_CONFIGS }).catch((error) => { - toast.error(error); + toast.error(`${error}`); }); if (res) { @@ -81,24 +80,15 @@ const updateOllamaHandler = async () => { if (ENABLE_OLLAMA_API !== null) { - // Remove duplicate URLs - OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter( - (url, urlIdx) => OLLAMA_BASE_URLS.indexOf(url) === urlIdx && url !== '' - ).map((url) => url.replace(/\/$/, '')); - - console.log(OLLAMA_BASE_URLS); - - if (OLLAMA_BASE_URLS.length === 0) { - ENABLE_OLLAMA_API = false; - toast.info($i18n.t('Ollama API disabled')); - } + // Remove trailing slashes + OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.map((url) => url.replace(/\/$/, '')); const res = await updateOllamaConfig(localStorage.token, { ENABLE_OLLAMA_API: ENABLE_OLLAMA_API, OLLAMA_BASE_URLS: OLLAMA_BASE_URLS, OLLAMA_API_CONFIGS: OLLAMA_API_CONFIGS }).catch((error) => { - toast.error(error); + toast.error(`${error}`); }); if (res) { @@ -111,14 +101,14 @@ const addOpenAIConnectionHandler = async (connection) => { OPENAI_API_BASE_URLS = [...OPENAI_API_BASE_URLS, connection.url]; OPENAI_API_KEYS = [...OPENAI_API_KEYS, connection.key]; - OPENAI_API_CONFIGS[connection.url] = connection.config; + OPENAI_API_CONFIGS[OPENAI_API_BASE_URLS.length] = connection.config; await updateOpenAIHandler(); }; const addOllamaConnectionHandler = async (connection) => { OLLAMA_BASE_URLS = [...OLLAMA_BASE_URLS, connection.url]; - OLLAMA_API_CONFIGS[connection.url] = connection.config; + OLLAMA_API_CONFIGS[OLLAMA_BASE_URLS.length] = connection.config; await updateOllamaHandler(); }; @@ -148,15 +138,17 @@ OLLAMA_API_CONFIGS = ollamaConfig.OLLAMA_API_CONFIGS; if (ENABLE_OPENAI_API) { - for (const url of OPENAI_API_BASE_URLS) { - if (!OPENAI_API_CONFIGS[url]) { - OPENAI_API_CONFIGS[url] = {}; + // get url and idx + for (const [idx, url] of OPENAI_API_BASE_URLS.entries()) { + if (!OPENAI_API_CONFIGS[idx]) { + // Legacy support, url as key + OPENAI_API_CONFIGS[idx] = OPENAI_API_CONFIGS[url] || {}; } } OPENAI_API_BASE_URLS.forEach(async (url, idx) => { - OPENAI_API_CONFIGS[url] = OPENAI_API_CONFIGS[url] || {}; - if (!(OPENAI_API_CONFIGS[url]?.enable ?? true)) { + OPENAI_API_CONFIGS[idx] = OPENAI_API_CONFIGS[idx] || {}; + if (!(OPENAI_API_CONFIGS[idx]?.enable ?? true)) { return; } const res = await getOpenAIModels(localStorage.token, idx); @@ -167,9 +159,9 @@ } if (ENABLE_OLLAMA_API) { - for (const url of OLLAMA_BASE_URLS) { - if (!OLLAMA_API_CONFIGS[url]) { - OLLAMA_API_CONFIGS[url] = {}; + for (const [idx, url] of OLLAMA_BASE_URLS.entries()) { + if (!OLLAMA_API_CONFIGS[idx]) { + OLLAMA_API_CONFIGS[idx] = OLLAMA_API_CONFIGS[url] || {}; } } } @@ -242,7 +234,7 @@ pipeline={pipelineUrls[url] ? true : false} bind:url bind:key={OPENAI_API_KEYS[idx]} - bind:config={OPENAI_API_CONFIGS[url]} + bind:config={OPENAI_API_CONFIGS[idx]} onSubmit={() => { updateOpenAIHandler(); }} @@ -251,6 +243,8 @@ (url, urlIdx) => idx !== urlIdx ); OPENAI_API_KEYS = OPENAI_API_KEYS.filter((key, keyIdx) => idx !== keyIdx); + + delete OPENAI_API_CONFIGS[idx]; }} /> {/each} @@ -301,13 +295,14 @@ {#each OLLAMA_BASE_URLS as url, idx} { updateOllamaHandler(); }} onDelete={() => { OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url, urlIdx) => idx !== urlIdx); + delete OLLAMA_API_CONFIGS[idx]; }} /> {/each} diff --git a/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte b/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte index 3f24dc6d7..a8726a546 100644 --- a/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte +++ b/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte @@ -37,7 +37,7 @@ const verifyOllamaHandler = async () => { const res = await verifyOllamaConnection(localStorage.token, url, key).catch((error) => { - toast.error(error); + toast.error(`${error}`); }); if (res) { @@ -47,7 +47,7 @@ const verifyOpenAIHandler = async () => { const res = await verifyOpenAIConnection(localStorage.token, url, key).catch((error) => { - toast.error(error); + toast.error(`${error}`); }); if (res) { diff --git a/src/lib/components/admin/Settings/Connections/ManageOllamaModal.svelte b/src/lib/components/admin/Settings/Connections/ManageOllamaModal.svelte index 220214ed1..543db060e 100644 --- a/src/lib/components/admin/Settings/Connections/ManageOllamaModal.svelte +++ b/src/lib/components/admin/Settings/Connections/ManageOllamaModal.svelte @@ -3,527 +3,13 @@ import { getContext, onMount } from 'svelte'; const i18n = getContext('i18n'); - import { WEBUI_NAME, models, MODEL_DOWNLOAD_POOL, user, config } from '$lib/stores'; - import { splitStream } from '$lib/utils'; - - import { - createModel, - deleteModel, - downloadModel, - getOllamaUrls, - getOllamaVersion, - pullModel, - uploadModel, - getOllamaConfig, - getOllamaModels - } from '$lib/apis/ollama'; - import { getModels } from '$lib/apis'; - import Modal from '$lib/components/common/Modal.svelte'; - import Tooltip from '$lib/components/common/Tooltip.svelte'; - import ModelDeleteConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; - import Spinner from '$lib/components/common/Spinner.svelte'; + import ManageOllama from '../Models/Manage/ManageOllama.svelte'; export let show = false; - - let modelUploadInputElement: HTMLInputElement; - let showModelDeleteConfirm = false; - - let loading = true; - - // Models export let urlIdx: number | null = null; - - let ollamaModels = []; - - let updateModelId = null; - let updateProgress = null; - let showExperimentalOllama = false; - - const MAX_PARALLEL_DOWNLOADS = 3; - - let modelTransferring = false; - let modelTag = ''; - - let createModelLoading = false; - let createModelTag = ''; - let createModelContent = ''; - let createModelDigest = ''; - let createModelPullProgress = null; - - let digest = ''; - let pullProgress = null; - - let modelUploadMode = 'file'; - let modelInputFile: File[] | null = null; - let modelFileUrl = ''; - let modelFileContent = `TEMPLATE """{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: """\nPARAMETER num_ctx 4096\nPARAMETER stop ""\nPARAMETER stop "USER:"\nPARAMETER stop "ASSISTANT:"`; - let modelFileDigest = ''; - - let uploadProgress = null; - let uploadMessage = ''; - - let deleteModelTag = ''; - - const updateModelsHandler = async () => { - for (const model of ollamaModels) { - console.log(model); - - updateModelId = model.id; - const [res, controller] = await pullModel(localStorage.token, model.id, urlIdx).catch( - (error) => { - toast.error(error); - return null; - } - ); - - if (res) { - const reader = res.body - .pipeThrough(new TextDecoderStream()) - .pipeThrough(splitStream('\n')) - .getReader(); - - while (true) { - try { - const { value, done } = await reader.read(); - if (done) break; - - let lines = value.split('\n'); - - for (const line of lines) { - if (line !== '') { - let data = JSON.parse(line); - - console.log(data); - if (data.error) { - throw data.error; - } - if (data.detail) { - throw data.detail; - } - if (data.status) { - if (data.digest) { - updateProgress = 0; - if (data.completed) { - updateProgress = Math.round((data.completed / data.total) * 1000) / 10; - } else { - updateProgress = 100; - } - } else { - toast.success(data.status); - } - } - } - } - } catch (error) { - console.log(error); - } - } - } - } - - updateModelId = null; - updateProgress = null; - }; - - const pullModelHandler = async () => { - const sanitizedModelTag = modelTag.trim().replace(/^ollama\s+(run|pull)\s+/, ''); - console.log($MODEL_DOWNLOAD_POOL); - if ($MODEL_DOWNLOAD_POOL[sanitizedModelTag]) { - toast.error( - $i18n.t(`Model '{{modelTag}}' is already in queue for downloading.`, { - modelTag: sanitizedModelTag - }) - ); - return; - } - if (Object.keys($MODEL_DOWNLOAD_POOL).length === MAX_PARALLEL_DOWNLOADS) { - toast.error( - $i18n.t('Maximum of 3 models can be downloaded simultaneously. Please try again later.') - ); - return; - } - - const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, urlIdx).catch( - (error) => { - toast.error(error); - return null; - } - ); - - if (res) { - const reader = res.body - .pipeThrough(new TextDecoderStream()) - .pipeThrough(splitStream('\n')) - .getReader(); - - MODEL_DOWNLOAD_POOL.set({ - ...$MODEL_DOWNLOAD_POOL, - [sanitizedModelTag]: { - ...$MODEL_DOWNLOAD_POOL[sanitizedModelTag], - abortController: controller, - reader, - done: false - } - }); - - while (true) { - try { - const { value, done } = await reader.read(); - if (done) break; - - let lines = value.split('\n'); - - for (const line of lines) { - if (line !== '') { - let data = JSON.parse(line); - console.log(data); - if (data.error) { - throw data.error; - } - if (data.detail) { - throw data.detail; - } - - if (data.status) { - if (data.digest) { - let downloadProgress = 0; - if (data.completed) { - downloadProgress = Math.round((data.completed / data.total) * 1000) / 10; - } else { - downloadProgress = 100; - } - - MODEL_DOWNLOAD_POOL.set({ - ...$MODEL_DOWNLOAD_POOL, - [sanitizedModelTag]: { - ...$MODEL_DOWNLOAD_POOL[sanitizedModelTag], - pullProgress: downloadProgress, - digest: data.digest - } - }); - } else { - toast.success(data.status); - - MODEL_DOWNLOAD_POOL.set({ - ...$MODEL_DOWNLOAD_POOL, - [sanitizedModelTag]: { - ...$MODEL_DOWNLOAD_POOL[sanitizedModelTag], - done: data.status === 'success' - } - }); - } - } - } - } - } catch (error) { - console.log(error); - if (typeof error !== 'string') { - error = error.message; - } - - toast.error(error); - // opts.callback({ success: false, error, modelName: opts.modelName }); - } - } - - console.log($MODEL_DOWNLOAD_POOL[sanitizedModelTag]); - - if ($MODEL_DOWNLOAD_POOL[sanitizedModelTag].done) { - toast.success( - $i18n.t(`Model '{{modelName}}' has been successfully downloaded.`, { - modelName: sanitizedModelTag - }) - ); - - models.set(await getModels(localStorage.token)); - } else { - toast.error($i18n.t('Download canceled')); - } - - delete $MODEL_DOWNLOAD_POOL[sanitizedModelTag]; - - MODEL_DOWNLOAD_POOL.set({ - ...$MODEL_DOWNLOAD_POOL - }); - } - - modelTag = ''; - modelTransferring = false; - }; - - const uploadModelHandler = async () => { - modelTransferring = true; - - let uploaded = false; - let fileResponse = null; - let name = ''; - - if (modelUploadMode === 'file') { - const file = modelInputFile ? modelInputFile[0] : null; - - if (file) { - uploadMessage = 'Uploading...'; - - fileResponse = await uploadModel(localStorage.token, file, urlIdx).catch((error) => { - toast.error(error); - return null; - }); - } - } else { - uploadProgress = 0; - fileResponse = await downloadModel(localStorage.token, modelFileUrl, urlIdx).catch( - (error) => { - toast.error(error); - return null; - } - ); - } - - if (fileResponse && fileResponse.ok) { - const reader = fileResponse.body - .pipeThrough(new TextDecoderStream()) - .pipeThrough(splitStream('\n')) - .getReader(); - - while (true) { - const { value, done } = await reader.read(); - if (done) break; - - try { - let lines = value.split('\n'); - - for (const line of lines) { - if (line !== '') { - let data = JSON.parse(line.replace(/^data: /, '')); - - if (data.progress) { - if (uploadMessage) { - uploadMessage = ''; - } - uploadProgress = data.progress; - } - - if (data.error) { - throw data.error; - } - - if (data.done) { - modelFileDigest = data.blob; - name = data.name; - uploaded = true; - } - } - } - } catch (error) { - console.log(error); - } - } - } else { - const error = await fileResponse?.json(); - toast.error(error?.detail ?? error); - } - - if (uploaded) { - const res = await createModel( - localStorage.token, - `${name}:latest`, - `FROM @${modelFileDigest}\n${modelFileContent}` - ); - - if (res && res.ok) { - const reader = res.body - .pipeThrough(new TextDecoderStream()) - .pipeThrough(splitStream('\n')) - .getReader(); - - while (true) { - const { value, done } = await reader.read(); - if (done) break; - - try { - let lines = value.split('\n'); - - for (const line of lines) { - if (line !== '') { - console.log(line); - let data = JSON.parse(line); - console.log(data); - - if (data.error) { - throw data.error; - } - if (data.detail) { - throw data.detail; - } - - if (data.status) { - if ( - !data.digest && - !data.status.includes('writing') && - !data.status.includes('sha256') - ) { - toast.success(data.status); - } else { - if (data.digest) { - digest = data.digest; - - if (data.completed) { - pullProgress = Math.round((data.completed / data.total) * 1000) / 10; - } else { - pullProgress = 100; - } - } - } - } - } - } - } catch (error) { - console.log(error); - toast.error(error); - } - } - } - } - - modelFileUrl = ''; - - if (modelUploadInputElement) { - modelUploadInputElement.value = ''; - } - modelInputFile = null; - modelTransferring = false; - uploadProgress = null; - - models.set(await getModels(localStorage.token)); - }; - - const deleteModelHandler = async () => { - const res = await deleteModel(localStorage.token, deleteModelTag, urlIdx).catch((error) => { - toast.error(error); - }); - - if (res) { - toast.success($i18n.t(`Deleted {{deleteModelTag}}`, { deleteModelTag })); - } - - deleteModelTag = ''; - models.set(await getModels(localStorage.token)); - }; - - const cancelModelPullHandler = async (model: string) => { - const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model]; - if (abortController) { - abortController.abort(); - } - if (reader) { - await reader.cancel(); - delete $MODEL_DOWNLOAD_POOL[model]; - MODEL_DOWNLOAD_POOL.set({ - ...$MODEL_DOWNLOAD_POOL - }); - await deleteModel(localStorage.token, model); - toast.success(`${model} download has been canceled`); - } - }; - - const createModelHandler = async () => { - createModelLoading = true; - const res = await createModel( - localStorage.token, - createModelTag, - createModelContent, - urlIdx - ).catch((error) => { - toast.error(error); - return null; - }); - - if (res && res.ok) { - const reader = res.body - .pipeThrough(new TextDecoderStream()) - .pipeThrough(splitStream('\n')) - .getReader(); - - while (true) { - const { value, done } = await reader.read(); - if (done) break; - - try { - let lines = value.split('\n'); - - for (const line of lines) { - if (line !== '') { - console.log(line); - let data = JSON.parse(line); - console.log(data); - - if (data.error) { - throw data.error; - } - if (data.detail) { - throw data.detail; - } - - if (data.status) { - if ( - !data.digest && - !data.status.includes('writing') && - !data.status.includes('sha256') - ) { - toast.success(data.status); - } else { - if (data.digest) { - createModelDigest = data.digest; - - if (data.completed) { - createModelPullProgress = - Math.round((data.completed / data.total) * 1000) / 10; - } else { - createModelPullProgress = 100; - } - } - } - } - } - } - } catch (error) { - console.log(error); - toast.error(error); - } - } - } - - models.set(await getModels(localStorage.token)); - - createModelLoading = false; - - createModelTag = ''; - createModelContent = ''; - createModelDigest = ''; - createModelPullProgress = null; - }; - - const init = async () => { - loading = true; - ollamaModels = await getOllamaModels(localStorage.token, urlIdx); - - console.log(ollamaModels); - loading = false; - }; - - $: if (show) { - init(); - } - { - deleteModelHandler(); - }} -/> -
@@ -533,31 +19,6 @@
{$i18n.t('Manage Ollama')}
- -
- - - -
-
- -
- {$i18n.t('To access the available model names for downloading,')} - {$i18n.t('click here.')} -
- - {#if Object.keys($MODEL_DOWNLOAD_POOL).length > 0} - {#each Object.keys($MODEL_DOWNLOAD_POOL) as model} - {#if 'pullProgress' in $MODEL_DOWNLOAD_POOL[model]} -
-
{model}
-
-
-
-
- {$MODEL_DOWNLOAD_POOL[model].pullProgress ?? 0}% -
-
- - - - -
- {#if 'digest' in $MODEL_DOWNLOAD_POOL[model]} -
- {$MODEL_DOWNLOAD_POOL[model].digest} -
- {/if} -
-
- {/if} - {/each} - {/if} -
- -
-
{$i18n.t('Delete a model')}
-
-
- -
- -
-
- -
-
{$i18n.t('Create a model')}
-
-
- - -