diff --git a/.dockerignore b/.dockerignore index d7e716758..2b4f7b5fc 100644 --- a/.dockerignore +++ b/.dockerignore @@ -16,4 +16,5 @@ _old uploads .ipynb_checkpoints **/*.db -_test \ No newline at end of file +_test +backend/data/* diff --git a/.github/workflows/deploy-to-hf-spaces.yml b/.github/workflows/deploy-to-hf-spaces.yml index aa8bbcfce..7fc66acf5 100644 --- a/.github/workflows/deploy-to-hf-spaces.yml +++ b/.github/workflows/deploy-to-hf-spaces.yml @@ -28,6 +28,8 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 + with: + lfs: true - name: Remove git history run: rm -rf .git @@ -52,7 +54,9 @@ jobs: - name: Set up Git and push to Space run: | git init --initial-branch=main + git lfs install git lfs track "*.ttf" + git lfs track "*.jpg" rm demo.gif git add . git commit -m "GitHub deploy: ${{ github.sha }}" diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ca694e65..f7416361d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,10 +5,177 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +### Added +- **🌐 Enhanced Translations**: Added Slovak language, improved Czech language. + +## [0.4.8] - 2024-12-07 + +### Added + +- **🔓 Bypass Model Access Control**: Introduced the 'BYPASS_MODEL_ACCESS_CONTROL' environment variable. Easily bypass model access controls for user roles when access control isn't required, simplifying workflows for trusted environments. +- **📝 Markdown in Banners**: Now supports markdown for banners, enabling richer, more visually engaging announcements. +- **🌐 Internationalization Updates**: Enhanced translations across multiple languages, further improving accessibility and global user experience. +- **🎨 Styling Enhancements**: General UI style refinements for a cleaner and more polished interface. +- **📋 Rich Text Reliability**: Improved the reliability and stability of rich text input across chats for smoother interactions. + +### Fixed + +- **💡 Tailwind Build Issue**: Resolved a breaking bug caused by Tailwind, ensuring smoother builds and overall system reliability. +- **📚 Knowledge Collection Query Fix**: Addressed API endpoint issues with querying knowledge collections, ensuring accurate and reliable information retrieval. + +## [0.4.7] - 2024-12-01 + +### Added + +- **✨ Prompt Input Auto-Completion**: Type a prompt and let AI intelligently suggest and complete your inputs. Simply press 'Tab' or swipe right on mobile to confirm. Available only with Rich Text Input (default setting). Disable via Admin Settings for full control. +- **🌍 Improved Translations**: Enhanced localization for multiple languages, ensuring a more polished and accessible experience for international users. + +### Fixed + +- **🛠️ Tools Export Issue**: Resolved a critical issue where exporting tools wasn’t functioning, restoring seamless export capabilities. +- **🔗 Model ID Registration**: Fixed an issue where model IDs weren’t registering correctly in the model editor, ensuring reliable model setup and tracking. +- **🖋️ Textarea Auto-Expansion**: Corrected a bug where textareas didn’t expand automatically on certain browsers, improving usability for multi-line inputs. +- **🔧 Ollama Embed Endpoint**: Addressed the /ollama/embed endpoint malfunction, ensuring consistent performance and functionality. 
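The Ollama embed entry above refers to the proxied embedding route defined in the Ollama app. A minimal smoke test of that fix might look like the sketch below, assuming the Ollama app is mounted under the `/ollama` prefix (so the full path is `/ollama/api/embed`), that an Open WebUI API key is sent as a Bearer token, and that the model id matches one available on the connected Ollama instance; all of these are assumptions, not values taken from this diff.

```python
import requests

BASE_URL = "http://localhost:3000"  # assumed Open WebUI address
API_KEY = "sk-..."                  # assumed Open WebUI API key (placeholder)

resp = requests.post(
    f"{BASE_URL}/ollama/api/embed",  # assumed mount prefix + /api/embed route
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "model": "llama3.1:latest",                     # assumed model id
        "input": ["Open WebUI embeddings smoke test"],  # /api/embed accepts a string or a list
    },
    timeout=30,
)
resp.raise_for_status()
# Ollama's /api/embed responses carry an "embeddings" list; fall back to the raw body otherwise.
print(resp.json().get("embeddings", resp.json()))
```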
+ +### Changed + +- **🎨 Knowledge Base Styling**: Refined knowledge base visuals for a cleaner, more modern look, laying the groundwork for further enhancements in upcoming releases. + +## [0.4.6] - 2024-11-26 + +### Added + +- **🌍 Enhanced Translations**: Various language translations improved to make the WebUI more accessible and user-friendly worldwide. + +### Fixed + +- **✏️ Textarea Shifting Bug**: Resolved the issue where the textarea shifted unexpectedly, ensuring a smoother typing experience. +- **⚙️ Model Configuration Modal**: Fixed the issue where the models configuration modal introduced in 0.4.5 wasn’t working for some users. +- **🔍 Legacy Query Support**: Restored functionality for custom query generation in RAG when using legacy prompts, ensuring both default and custom templates now work seamlessly. +- **⚡ Improved General Reliability**: Various minor fixes improve platform stability and ensure a smoother overall experience across workflows. + +## [0.4.5] - 2024-11-26 + +### Added + +- **🎨 Model Order/Defaults Reintroduced**: Brought back the ability to set model order and default models, now configurable via Admin Settings > Models > Configure (Gear Icon). + +### Fixed + +- **🔍 Query Generation Issue**: Resolved an error in web search query generation, enhancing search accuracy and ensuring smoother search workflows. +- **📏 Textarea Auto Height Bug**: Fixed a layout issue where textarea input height was shifting unpredictably, particularly when editing system prompts. +- **🔑 Ollama Authentication**: Corrected an issue with Ollama’s authorization headers, guaranteeing reliable authentication across all endpoints. +- **⚙️ Missing Min_P Save**: Resolved an issue where the 'min_p' parameter was not being saved in configurations. +- **🛠️ Tools Description**: Fixed a key issue that omitted tool descriptions in tools payload. + +## [0.4.4] - 2024-11-22 + +### Added + +- **🌐 Translation Updates**: Refreshed Catalan, Brazilian Portuguese, German, and Ukrainian translations, further enhancing the platform's accessibility and improving the experience for international users. + +### Fixed + +- **📱 Mobile Controls Visibility**: Resolved an issue where the controls button was not displaying on the new chats page for mobile users, ensuring smoother navigation and functionality on smaller screens. +- **📷 LDAP Profile Image Issue**: Fixed an LDAP integration bug related to profile images, ensuring seamless authentication and a reliable login experience for users. +- **⏳ RAG Query Generation Issue**: Addressed a significant problem where RAG query generation occurred unnecessarily without attached files, drastically improving speed and reducing delays during chat completions. + +### Changed + +- **⚙️ Legacy Event Emitter Support**: Reintroduced compatibility with legacy "citation" types for event emitters in tools and functions, providing smoother workflows and broader tool support for users. + +## [0.4.3] - 2024-11-21 + +### Added + +- **📚 Inline Citations for RAG Results**: Get seamless inline citations for Retrieval-Augmented Generation (RAG) responses using the default RAG prompt. Note: This feature only supports newly uploaded files, improving traceability and providing source clarity. +- **🎨 Better Rich Text Input Support**: Enjoy smoother and more reliable rich text formatting for chats, enhancing communication quality. +- **⚡ Faster Model Retrieval**: Implemented caching optimizations for faster model loading, providing a noticeable speed boost across workflows. 
Further improvements are on the way! + +### Fixed + +- **🔗 Pipelines Feature Restored**: Resolved a critical issue that previously prevented Pipelines from functioning, ensuring seamless workflows. +- **✏️ Missing Suffix Field in Ollama Form**: Added the missing "suffix" field to the Ollama generate form, enhancing customization options. + +### Changed + +- **🗂️ Renamed "Citations" to "Sources"**: Improved clarity and consistency by renaming the "citations" field to "sources" in messages. + +## [0.4.2] - 2024-11-20 + +### Fixed + +- **📁 Knowledge Files Visibility Issue**: Resolved the bug preventing individual files in knowledge collections from displaying when referenced with '#'. +- **🔗 OpenAI Endpoint Prefix**: Fixed the issue where certain OpenAI connections that deviate from the official API spec weren’t working correctly with prefixes. +- **⚔️ Arena Model Access Control**: Corrected an issue where arena model access control settings were not being saved. +- **🔧 Usage Capability Selector**: Fixed the broken usage capabilities selector in the model editor. + +## [0.4.1] - 2024-11-19 + +### Added + +- **📊 Enhanced Feedback System**: Introduced a detailed 1-10 rating scale for feedback alongside thumbs up/down, preparing for more precise model fine-tuning and improving feedback quality. +- **ℹ️ Tool Descriptions on Hover**: Easily access tool descriptions by hovering over the message input, providing a smoother workflow with more context when utilizing tools. + +### Fixed + +- **🗑️ Graceful Handling of Deleted Users**: Resolved an issue where deleted users caused workspace items (models, knowledge, prompts, tools) to fail, ensuring reliable workspace loading. +- **🔑 API Key Creation**: Fixed an issue preventing users from creating new API keys, restoring secure and seamless API management. +- **🔗 HTTPS Proxy Fix**: Corrected HTTPS proxy issues affecting the '/api/v1/models/' endpoint, ensuring smoother, uninterrupted model management. + +## [0.4.0] - 2024-11-19 + +### Added + +- **👥 User Groups**: You can now create and manage user groups, making user organization seamless. +- **🔐 Group-Based Access Control**: Set granular access to models, knowledge, prompts, and tools based on user groups, allowing for more controlled and secure environments. +- **🛠️ Group-Based User Permissions**: Easily manage workspace permissions. Grant users the ability to upload files, delete, edit, or create temporary chats, as well as define their ability to create models, knowledge, prompts, and tools. +- **🔑 LDAP Support**: Newly introduced LDAP authentication adds robust security and scalability to user management. +- **🌐 Enhanced OpenAI-Compatible Connections**: Added prefix ID support to avoid model ID clashes, with explicit model ID support for APIs lacking '/models' endpoint support, ensuring smooth operation with custom setups. +- **🔐 Ollama API Key Support**: Now manage credentials for Ollama when set behind proxies, including the option to utilize prefix ID for proper distinction across multiple Ollama instances. +- **🔄 Connection Enable/Disable Toggle**: Easily enable or disable individual OpenAI and Ollama connections as needed. +- **🎨 Redesigned Model Workspace**: Freshly redesigned to improve usability for managing models across users and groups. +- **🎨 Redesigned Prompt Workspace**: A fresh UI to conveniently organize and manage prompts. +- **🧩 Sorted Functions Workspace**: Functions are now automatically categorized by type (Action, Filter, Pipe), streamlining management. 
+- **💻 Redesigned Collaborative Workspace**: Enhanced support for multiple users contributing to models, knowledge, prompts, or tools, improving collaboration. +- **🔧 Auto-Selected Tools in Model Editor**: Tools enabled through the model editor are now automatically selected, whereas previously it only gave users the option to enable the tool, reducing manual steps and enhancing efficiency. +- **🔔 Web Search & Tools Indicator**: A clear indication now shows when web search or tools are active, reducing confusion. +- **🔑 Toggle API Key Auth**: Tighten security by easily enabling or disabling API key authentication option for Open WebUI. +- **🗂️ Agentic Retrieval**: Improve RAG accuracy via smart pre-processing of chat history to determine the best queries before retrieval. +- **📁 Large Text as File Option**: Optionally convert large pasted text into a file upload, keeping the chat interface cleaner. +- **🗂️ Toggle Citations for Models**: Ability to disable citations has been introduced in the model editor. +- **🔍 User Settings Search**: Quickly search for settings fields, improving ease of use and navigation. +- **🗣️ Experimental SpeechT5 TTS**: Local SpeechT5 support added for improved text-to-speech capabilities. +- **🔄 Unified Reset for Models**: A one-click option has been introduced to reset and remove all models from the Admin Settings. +- **🛠️ Initial Setup Wizard**: The setup process now explicitly informs users that they are creating an admin account during the first-time setup, ensuring clarity. Previously, users encountered the login page right away without this distinction. +- **🌐 Enhanced Translations**: Several language translations, including Ukrainian, Norwegian, and Brazilian Portuguese, were refined for better localization. + +### Fixed + +- **🎥 YouTube Video Attachments**: Fixed issues preventing proper loading and attachment of YouTube videos as files. +- **🔄 Shared Chat Update**: Corrected issues where shared chats were not updating, improving collaboration consistency. +- **🔍 DuckDuckGo Rate Limit Fix**: Addressed issues with DuckDuckGo search integration, enhancing search stability and performance when operating within rate limits. +- **🧾 Citations Relevance Fix**: Adjusted the relevance percentage calculation for citations, so that Open WebUI properly reflect the accuracy of a retrieved document in RAG, ensuring users get clearer insights into sources. +- **🔑 Jina Search API Key Requirement**: Added the option to input an API key for Jina Search, ensuring smooth functionality as keys are now mandatory. + +### Changed + +- **🛠️ Functions Moved to Admin Panel**: As Functions operate as advanced plugins, they are now accessible from the Admin Panel instead of the workspace. +- **🛠️ Manage Ollama Connections**: The "Models" section in Admin Settings has been relocated to Admin Settings > "Connections" > Ollama Connections. You can now manage Ollama instances via a dedicated "Manage Ollama" modal from "Connections", streamlining the setup and configuration of Ollama models. +- **📊 Base Models in Admin Settings**: Admins can now find all base models, both connections or functions, in the "Models" Admin setting. Global model accessibility can be enabled or disabled here. Models are private by default, requiring explicit permission assignment for user access. +- **📌 Sticky Model Selection for New Chats**: The model chosen from a previous chat now persists when creating a new chat. If you click "New Chat" again from the new chat page, it will revert to your default model. 
+- **🎨 Design Refactoring**: Overall design refinements across the platform have been made, providing a more cohesive and polished user experience. + +### Removed + +- **📂 Model List Reordering**: Temporarily removed and will be reintroduced in upcoming user group settings improvements. +- **⚙️ Default Model Setting**: Removed the ability to set a default model for users, will be reintroduced with user group settings in the future. + ## [0.3.35] - 2024-10-26 ### Added +- **🌐 Translation Update**: Added translation labels in the SearchInput and CreateCollection components and updated Brazilian Portuguese translation (pt-BR) - **📁 Robust File Handling**: Enhanced file input handling for chat. If the content extraction fails or is empty, users will now receive a clear warning, preventing silent failures and ensuring you always know what's happening with your uploads. - **🌍 New Language Support**: Introduced Hungarian translations and updated French translations, expanding the platform's language accessibility for a more global user base. diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 37ac5263c..b1c7b56a3 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -2,76 +2,98 @@ ## Our Pledge -We as members, contributors, and leaders pledge to make participation in our -community a harassment-free experience for everyone, regardless of age, body -size, visible or invisible disability, ethnicity, sex characteristics, gender -identity and expression, level of experience, education, socio-economic status, -nationality, personal appearance, race, religion, or sexual identity -and orientation. +As members, contributors, and leaders of this community, we pledge to make participation in our open-source project a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. -We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. +We are committed to creating and maintaining an open, respectful, and professional environment where positive contributions and meaningful discussions can flourish. By participating in this project, you agree to uphold these values and align your behavior to the standards outlined in this Code of Conduct. + +## Why These Standards Are Important + +Open-source projects rely on a community of volunteers dedicating their time, expertise, and effort toward a shared goal. These projects are inherently collaborative but also fragile, as the success of the project depends on the goodwill, energy, and productivity of those involved. + +Maintaining a positive and respectful environment is essential to safeguarding the integrity of this project and protecting contributors' efforts. Behavior that disrupts this atmosphere—whether through hostility, entitlement, or unprofessional conduct—can severely harm the morale and productivity of the community. **Strict enforcement of these standards ensures a safe and supportive space for meaningful collaboration.** + +This is a community where **respect and professionalism are mandatory.** Violations of these standards will result in **zero tolerance** and immediate enforcement to prevent disruption and ensure the well-being of all participants. 
## Our Standards -Examples of behavior that contribute to a positive environment for our community include: +Examples of behavior that contribute to a positive and professional community include: -- Demonstrating empathy and kindness toward other people -- Being respectful of differing opinions, viewpoints, and experiences -- Giving and gracefully accepting constructive feedback -- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience -- Focusing on what is best not just for us as individuals, but for the overall community +- **Respecting others.** Be considerate, listen actively, and engage with empathy toward others' viewpoints and experiences. +- **Constructive feedback.** Provide actionable, thoughtful, and respectful feedback that helps improve the project and encourages collaboration. Avoid unproductive negativity or hypercriticism. +- **Recognizing volunteer contributions.** Appreciate that contributors dedicate their free time and resources selflessly. Approach them with gratitude and patience. +- **Focusing on shared goals.** Collaborate in ways that prioritize the health, success, and sustainability of the community over individual agendas. Examples of unacceptable behavior include: -- The use of sexualized language or imagery, and sexual attention or advances of any kind -- Trolling, insulting or derogatory comments, and personal or political attacks -- Public or private harassment -- Publishing others' private information, such as a physical or email address, without their explicit permission -- **Spamming of any kind** -- Aggressive sales tactics targeting our community members are strictly prohibited. You can mention your product if it's relevant to the discussion, but under no circumstances should you push it forcefully -- Other conduct which could reasonably be considered inappropriate in a professional setting +- The use of discriminatory, demeaning, or sexualized language or behavior. +- Personal attacks, derogatory comments, trolling, or inflammatory political or ideological arguments. +- Harassment, intimidation, or any behavior intended to create a hostile, uncomfortable, or unsafe environment. +- Publishing others' private information (e.g., physical or email addresses) without explicit permission. +- **Entitlement, demand, or aggression toward contributors.** Volunteers are under no obligation to provide immediate or personalized support. Rude or dismissive behavior will not be tolerated. +- **Unproductive or destructive behavior.** This includes venting frustration as hostility ("tantrums"), hypercriticism, attention-seeking negativity, or anything that distracts from the project's goals. +- **Spamming and promotional exploitation.** Sharing irrelevant product promotions or self-promotion in the community is not allowed unless it directly contributes value to the discussion. + +### Feedback and Community Engagement + +- **Constructive feedback is encouraged, but hostile or entitled behavior will result in immediate action.** If you disagree with elements of the project, we encourage you to offer meaningful improvements or fork the project if necessary. Healthy discussions and technical disagreements are welcome only when handled with professionalism. +- **Respect contributors' time and efforts.** No one is entitled to personalized or on-demand assistance. This is a community built on collaboration and shared effort; demanding or demeaning behavior undermines that trust and will not be allowed. 
+ +### Zero Tolerance: No Warnings, Immediate Action + +This community operates under a **zero-tolerance policy.** Any behavior deemed unacceptable under this Code of Conduct will result in **immediate enforcement, without prior warning.** + +We employ this approach to ensure that unproductive or disruptive behavior does not escalate further or cause unnecessary harm to other contributors. The standards are clear, and violations of any kind—whether mild or severe—will be addressed decisively to protect the community. ## Enforcement Responsibilities -Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. +Community leaders are responsible for upholding and enforcing these standards. They are empowered to take **immediate and appropriate action** to address any behaviors they deem unacceptable under this Code of Conduct. These actions are taken with the goal of protecting the community and preserving its safe, positive, and productive environment. ## Scope -This Code of Conduct applies within all community spaces and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. +This Code of Conduct applies to all community spaces, including forums, repositories, social media accounts, and in-person events. It also applies when an individual represents the community in public settings, such as conferences or official communications. -## Enforcement +Additionally, any behavior outside of these defined spaces that negatively impacts the community or its members may fall within the scope of this Code of Conduct. -Instances of abusive, harassing, spamming, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at hello@openwebui.com. All complaints will be reviewed and investigated promptly and fairly. +## Reporting Violations -All community leaders are obligated to respect the privacy and security of the reporter of any incident. +Instances of unacceptable behavior can be reported to the leadership team at **hello@openwebui.com**. Reports will be handled promptly, confidentially, and with consideration for the safety and well-being of the reporter. + +All community leaders are required to uphold confidentiality and impartiality when addressing reports of violations. ## Enforcement Guidelines -Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: +### Ban -### 1. Temporary Ban +**Community Impact**: Community leaders will issue a ban to any participant whose behavior is deemed unacceptable according to this Code of Conduct. Bans are enforced immediately and without prior notice. -**Community Impact**: Any violation of community standards, including but not limited to inappropriate language, unprofessional behavior, harassment, or spamming. +A ban may be temporary or permanent, depending on the severity of the violation. This includes—but is not limited to—behavior such as: -**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. 
No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. +- Harassment or abusive behavior toward contributors. +- Persistent negativity or hostility that disrupts the collaborative environment. +- Disrespectful, demanding, or aggressive interactions with others. +- Attempts to cause harm or sabotage the community. -### 2. Permanent Ban +**Consequence**: A banned individual is immediately removed from access to all community spaces, communication channels, and events. Community leaders reserve the right to enforce either a time-limited suspension or a permanent ban based on the specific circumstances of the violation. -**Community Impact**: Repeated or severe violations of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. +This approach ensures that disruptive behaviors are addressed swiftly and decisively in order to maintain the integrity and productivity of the community. -**Consequence**: A permanent ban from any sort of public interaction within the community. +## Why Zero Tolerance Is Necessary + +Open-source projects thrive on collaboration, goodwill, and mutual respect. Toxic behaviors—such as entitlement, hostility, or persistent negativity—threaten not just individual contributors but the health of the project as a whole. Allowing such behaviors to persist robs contributors of their time, energy, and enthusiasm for the work they do. + +By enforcing a zero-tolerance policy, we ensure that the community remains a safe, welcoming space for all participants. These measures are not about harshness—they are about protecting contributors and fostering a productive environment where innovation can thrive. + +Our expectations are clear, and our enforcement reflects our commitment to this project's long-term success. ## Attribution -This Code of Conduct is adapted from the [Contributor Covenant][homepage], -version 2.0, available at +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. -Community Impact Guidelines were inspired by [Mozilla's code of conduct -enforcement ladder](https://github.com/mozilla/diversity). +Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org -For answers to common questions about this code of conduct, see the FAQ at -https://www.contributor-covenant.org/faq. Translations are available at +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. diff --git a/README.md b/README.md index 7766d681e..4ac495a47 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature- - 🤝 **Ollama/OpenAI API Integration**: Effortlessly integrate OpenAI-compatible APIs for versatile conversations alongside Ollama models. Customize the OpenAI API URL to link with **LMStudio, GroqCloud, Mistral, OpenRouter, and more**. 
-- 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more. +- 🛡️ **Granular Permissions and User Groups**: By allowing administrators to create detailed user roles and permissions, we ensure a secure user environment. This granularity not only enhances security but also allows for customized user experiences, fostering a sense of ownership and responsibility amongst users. - 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices. @@ -37,7 +37,7 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature- - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query. -- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch` and `SearchApi` and inject the results directly into your chat experience. +- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch`, `SearchApi` and `Bing` and inject the results directly into your chat experience. - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions. @@ -49,6 +49,8 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature- - 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors! +- 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more. + - 🌟 **Continuous Updates**: We are committed to improving Open WebUI with regular updates, fixes, and new features. Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview! 
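Both the OpenAI-compatible connection bullet and the Pipelines bullet rely on the same mechanism: anything that speaks the OpenAI chat-completions API can be wired in by changing the base URL. The sketch below illustrates that idea from the client side with the official `openai` Python package; the URL, key, and model id are placeholders, not values from this repository.

```python
from openai import OpenAI

# Point the client at any OpenAI-compatible endpoint (Pipelines, LMStudio, GroqCloud, ...).
client = OpenAI(
    base_url="http://localhost:9099/v1",  # placeholder endpoint; adjust to your server's base path
    api_key="your-api-key",               # whatever key that endpoint expects
)

response = client.chat.completions.create(
    model="my-pipeline",  # hypothetical model/pipeline id exposed by the endpoint
    messages=[{"role": "user", "content": "Hello from an OpenAI-compatible client"}],
)
print(response.choices[0].message.content)
```

Open WebUI itself does the equivalent server-side when the OpenAI API URL in its connection settings is pointed at such an endpoint.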
diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py deleted file mode 100644 index 50ffad12b..000000000 --- a/backend/open_webui/apps/audio/main.py +++ /dev/null @@ -1,641 +0,0 @@ -import hashlib -import json -import logging -import os -import uuid -from functools import lru_cache -from pathlib import Path -from pydub import AudioSegment -from pydub.silence import split_on_silence - -import requests -from open_webui.config import ( - AUDIO_STT_ENGINE, - AUDIO_STT_MODEL, - AUDIO_STT_OPENAI_API_BASE_URL, - AUDIO_STT_OPENAI_API_KEY, - AUDIO_TTS_API_KEY, - AUDIO_TTS_ENGINE, - AUDIO_TTS_MODEL, - AUDIO_TTS_OPENAI_API_BASE_URL, - AUDIO_TTS_OPENAI_API_KEY, - AUDIO_TTS_SPLIT_ON, - AUDIO_TTS_VOICE, - AUDIO_TTS_AZURE_SPEECH_REGION, - AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, - CACHE_DIR, - CORS_ALLOW_ORIGIN, - WHISPER_MODEL, - WHISPER_MODEL_AUTO_UPDATE, - WHISPER_MODEL_DIR, - AppConfig, -) - -from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ENV, SRC_LOG_LEVELS, DEVICE_TYPE -from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse -from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_verified_user - -# Constants -MAX_FILE_SIZE_MB = 25 -MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes - - -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["AUDIO"]) - -app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None) - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.config = AppConfig() - -app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL -app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY -app.state.config.STT_ENGINE = AUDIO_STT_ENGINE -app.state.config.STT_MODEL = AUDIO_STT_MODEL - -app.state.config.WHISPER_MODEL = WHISPER_MODEL -app.state.faster_whisper_model = None - -app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL -app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY -app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE -app.state.config.TTS_MODEL = AUDIO_TTS_MODEL -app.state.config.TTS_VOICE = AUDIO_TTS_VOICE -app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY -app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON - -app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION -app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT - -# setting device type for whisper model -whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" -log.info(f"whisper_device_type: {whisper_device_type}") - -SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") -SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) - - -def set_faster_whisper_model(model: str, auto_update: bool = False): - if model and app.state.config.STT_ENGINE == "": - from faster_whisper import WhisperModel - - faster_whisper_kwargs = { - "model_size_or_path": model, - "device": whisper_device_type, - "compute_type": "int8", - "download_root": WHISPER_MODEL_DIR, - "local_files_only": not auto_update, - } - - try: - app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs) - except Exception: - log.warning( - "WhisperModel initialization failed, attempting download with 
local_files_only=False" - ) - faster_whisper_kwargs["local_files_only"] = False - app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs) - - else: - app.state.faster_whisper_model = None - - -class TTSConfigForm(BaseModel): - OPENAI_API_BASE_URL: str - OPENAI_API_KEY: str - API_KEY: str - ENGINE: str - MODEL: str - VOICE: str - SPLIT_ON: str - AZURE_SPEECH_REGION: str - AZURE_SPEECH_OUTPUT_FORMAT: str - - -class STTConfigForm(BaseModel): - OPENAI_API_BASE_URL: str - OPENAI_API_KEY: str - ENGINE: str - MODEL: str - WHISPER_MODEL: str - - -class AudioConfigUpdateForm(BaseModel): - tts: TTSConfigForm - stt: STTConfigForm - - -from pydub import AudioSegment -from pydub.utils import mediainfo - - -def is_mp4_audio(file_path): - """Check if the given file is an MP4 audio file.""" - if not os.path.isfile(file_path): - print(f"File not found: {file_path}") - return False - - info = mediainfo(file_path) - if ( - info.get("codec_name") == "aac" - and info.get("codec_type") == "audio" - and info.get("codec_tag_string") == "mp4a" - ): - return True - return False - - -def convert_mp4_to_wav(file_path, output_path): - """Convert MP4 audio file to WAV format.""" - audio = AudioSegment.from_file(file_path, format="mp4") - audio.export(output_path, format="wav") - print(f"Converted {file_path} to {output_path}") - - -@app.get("/config") -async def get_audio_config(user=Depends(get_admin_user)): - return { - "tts": { - "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, - "API_KEY": app.state.config.TTS_API_KEY, - "ENGINE": app.state.config.TTS_ENGINE, - "MODEL": app.state.config.TTS_MODEL, - "VOICE": app.state.config.TTS_VOICE, - "SPLIT_ON": app.state.config.TTS_SPLIT_ON, - "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION, - "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, - }, - "stt": { - "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, - "ENGINE": app.state.config.STT_ENGINE, - "MODEL": app.state.config.STT_MODEL, - "WHISPER_MODEL": app.state.config.WHISPER_MODEL, - }, - } - - -@app.post("/config/update") -async def update_audio_config( - form_data: AudioConfigUpdateForm, user=Depends(get_admin_user) -): - app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL - app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY - app.state.config.TTS_API_KEY = form_data.tts.API_KEY - app.state.config.TTS_ENGINE = form_data.tts.ENGINE - app.state.config.TTS_MODEL = form_data.tts.MODEL - app.state.config.TTS_VOICE = form_data.tts.VOICE - app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON - app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION - app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = ( - form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT - ) - - app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL - app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY - app.state.config.STT_ENGINE = form_data.stt.ENGINE - app.state.config.STT_MODEL = form_data.stt.MODEL - app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL - set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE) - - return { - "tts": { - "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, - "API_KEY": app.state.config.TTS_API_KEY, - "ENGINE": 
app.state.config.TTS_ENGINE, - "MODEL": app.state.config.TTS_MODEL, - "VOICE": app.state.config.TTS_VOICE, - "SPLIT_ON": app.state.config.TTS_SPLIT_ON, - "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION, - "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, - }, - "stt": { - "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, - "ENGINE": app.state.config.STT_ENGINE, - "MODEL": app.state.config.STT_MODEL, - "WHISPER_MODEL": app.state.config.WHISPER_MODEL, - }, - } - - -@app.post("/speech") -async def speech(request: Request, user=Depends(get_verified_user)): - body = await request.body() - name = hashlib.sha256(body).hexdigest() - - file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3") - file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json") - - # Check if the file already exists in the cache - if file_path.is_file(): - return FileResponse(file_path) - - if app.state.config.TTS_ENGINE == "openai": - headers = {} - headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}" - headers["Content-Type"] = "application/json" - - try: - body = body.decode("utf-8") - body = json.loads(body) - body["model"] = app.state.config.TTS_MODEL - body = json.dumps(body).encode("utf-8") - except Exception: - pass - - r = None - try: - r = requests.post( - url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", - data=body, - headers=headers, - stream=True, - ) - - r.raise_for_status() - - # Save the streaming content to a file - with open(file_path, "wb") as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) - - with open(file_body_path, "w") as f: - json.dump(json.loads(body.decode("utf-8")), f) - - # Return the saved file - return FileResponse(file_path) - - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"External: {res['error']['message']}" - except Exception: - error_detail = f"External: {e}" - - raise HTTPException( - status_code=r.status_code if r != None else 500, - detail=error_detail, - ) - - elif app.state.config.TTS_ENGINE == "elevenlabs": - payload = None - try: - payload = json.loads(body.decode("utf-8")) - except Exception as e: - log.exception(e) - raise HTTPException(status_code=400, detail="Invalid JSON payload") - - voice_id = payload.get("voice", "") - - if voice_id not in get_available_voices(): - raise HTTPException( - status_code=400, - detail="Invalid voice id", - ) - - url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}" - - headers = { - "Accept": "audio/mpeg", - "Content-Type": "application/json", - "xi-api-key": app.state.config.TTS_API_KEY, - } - - data = { - "text": payload["input"], - "model_id": app.state.config.TTS_MODEL, - "voice_settings": {"stability": 0.5, "similarity_boost": 0.5}, - } - - try: - r = requests.post(url, json=data, headers=headers) - - r.raise_for_status() - - # Save the streaming content to a file - with open(file_path, "wb") as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) - - with open(file_body_path, "w") as f: - json.dump(json.loads(body.decode("utf-8")), f) - - # Return the saved file - return FileResponse(file_path) - - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"External: {res['error']['message']}" - 
except Exception: - error_detail = f"External: {e}" - - raise HTTPException( - status_code=r.status_code if r != None else 500, - detail=error_detail, - ) - - elif app.state.config.TTS_ENGINE == "azure": - payload = None - try: - payload = json.loads(body.decode("utf-8")) - except Exception as e: - log.exception(e) - raise HTTPException(status_code=400, detail="Invalid JSON payload") - - region = app.state.config.TTS_AZURE_SPEECH_REGION - language = app.state.config.TTS_VOICE - locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1]) - output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT - url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1" - - headers = { - "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY, - "Content-Type": "application/ssml+xml", - "X-Microsoft-OutputFormat": output_format, - } - - data = f""" - {payload["input"]} - """ - - response = requests.post(url, headers=headers, data=data) - - if response.status_code == 200: - with open(file_path, "wb") as f: - f.write(response.content) - return FileResponse(file_path) - else: - log.error(f"Error synthesizing speech - {response.reason}") - raise HTTPException( - status_code=500, detail=f"Error synthesizing speech - {response.reason}" - ) - - -def transcribe(file_path): - print("transcribe", file_path) - filename = os.path.basename(file_path) - file_dir = os.path.dirname(file_path) - id = filename.split(".")[0] - - if app.state.config.STT_ENGINE == "": - if app.state.faster_whisper_model is None: - set_faster_whisper_model(app.state.config.WHISPER_MODEL) - - model = app.state.faster_whisper_model - segments, info = model.transcribe(file_path, beam_size=5) - log.info( - "Detected language '%s' with probability %f" - % (info.language, info.language_probability) - ) - - transcript = "".join([segment.text for segment in list(segments)]) - data = {"text": transcript.strip()} - - # save the transcript to a json file - transcript_file = f"{file_dir}/{id}.json" - with open(transcript_file, "w") as f: - json.dump(data, f) - - log.debug(data) - return data - elif app.state.config.STT_ENGINE == "openai": - if is_mp4_audio(file_path): - print("is_mp4_audio") - os.rename(file_path, file_path.replace(".wav", ".mp4")) - # Convert MP4 audio file to WAV format - convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path) - - headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"} - - files = {"file": (filename, open(file_path, "rb"))} - data = {"model": app.state.config.STT_MODEL} - - log.debug(files, data) - - r = None - try: - r = requests.post( - url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", - headers=headers, - files=files, - data=data, - ) - - r.raise_for_status() - - data = r.json() - - # save the transcript to a json file - transcript_file = f"{file_dir}/{id}.json" - with open(transcript_file, "w") as f: - json.dump(data, f) - - print(data) - return data - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"External: {res['error']['message']}" - except Exception: - error_detail = f"External: {e}" - - raise Exception(error_detail) - - -@app.post("/transcriptions") -def transcription( - file: UploadFile = File(...), - user=Depends(get_verified_user), -): - log.info(f"file.content_type: {file.content_type}") - - if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]: - raise HTTPException( - 
status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED, - ) - - try: - ext = file.filename.split(".")[-1] - id = uuid.uuid4() - - filename = f"{id}.{ext}" - contents = file.file.read() - - file_dir = f"{CACHE_DIR}/audio/transcriptions" - os.makedirs(file_dir, exist_ok=True) - file_path = f"{file_dir}/{filename}" - - with open(file_path, "wb") as f: - f.write(contents) - - try: - if os.path.getsize(file_path) > MAX_FILE_SIZE: # file is bigger than 25MB - log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB") - audio = AudioSegment.from_file(file_path) - audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio - compressed_path = f"{file_dir}/{id}_compressed.opus" - audio.export(compressed_path, format="opus", bitrate="32k") - log.debug(f"Compressed audio to {compressed_path}") - file_path = compressed_path - - if ( - os.path.getsize(file_path) > MAX_FILE_SIZE - ): # Still larger than 25MB after compression - log.debug( - f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}" - ) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.FILE_TOO_LARGE( - size=f"{MAX_FILE_SIZE_MB}MB" - ), - ) - - data = transcribe(file_path) - else: - data = transcribe(file_path) - - file_path = file_path.split("/")[-1] - return {**data, "filename": file_path} - except Exception as e: - log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - except Exception as e: - log.exception(e) - - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - -def get_available_models() -> list[dict]: - if app.state.config.TTS_ENGINE == "openai": - return [{"id": "tts-1"}, {"id": "tts-1-hd"}] - elif app.state.config.TTS_ENGINE == "elevenlabs": - headers = { - "xi-api-key": app.state.config.TTS_API_KEY, - "Content-Type": "application/json", - } - - try: - response = requests.get( - "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5 - ) - response.raise_for_status() - models = response.json() - return [ - {"name": model["name"], "id": model["model_id"]} for model in models - ] - except requests.RequestException as e: - log.error(f"Error fetching voices: {str(e)}") - return [] - - -@app.get("/models") -async def get_models(user=Depends(get_verified_user)): - return {"models": get_available_models()} - - -def get_available_voices() -> dict: - """Returns {voice_id: voice_name} dict""" - ret = {} - if app.state.config.TTS_ENGINE == "openai": - ret = { - "alloy": "alloy", - "echo": "echo", - "fable": "fable", - "onyx": "onyx", - "nova": "nova", - "shimmer": "shimmer", - } - elif app.state.config.TTS_ENGINE == "elevenlabs": - try: - ret = get_elevenlabs_voices() - except Exception: - # Avoided @lru_cache with exception - pass - elif app.state.config.TTS_ENGINE == "azure": - try: - region = app.state.config.TTS_AZURE_SPEECH_REGION - url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list" - headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY} - - response = requests.get(url, headers=headers) - response.raise_for_status() - voices = response.json() - for voice in voices: - ret[voice["ShortName"]] = ( - f"{voice['DisplayName']} ({voice['ShortName']})" - ) - except requests.RequestException as e: - log.error(f"Error fetching voices: {str(e)}") - - return ret - - -@lru_cache -def get_elevenlabs_voices() -> dict: - """ - Note, set the following in 
your .env file to use Elevenlabs: - AUDIO_TTS_ENGINE=elevenlabs - AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key - AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices - AUDIO_TTS_MODEL=eleven_multilingual_v2 - """ - headers = { - "xi-api-key": app.state.config.TTS_API_KEY, - "Content-Type": "application/json", - } - try: - # TODO: Add retries - response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers) - response.raise_for_status() - voices_data = response.json() - - voices = {} - for voice in voices_data.get("voices", []): - voices[voice["voice_id"]] = voice["name"] - except requests.RequestException as e: - # Avoid @lru_cache with exception - log.error(f"Error fetching voices: {str(e)}") - raise RuntimeError(f"Error fetching voices: {str(e)}") - - return voices - - -@app.get("/voices") -async def get_voices(user=Depends(get_verified_user)): - return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]} diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py deleted file mode 100644 index 8976d55b4..000000000 --- a/backend/open_webui/apps/ollama/main.py +++ /dev/null @@ -1,1123 +0,0 @@ -import asyncio -import json -import logging -import os -import random -import re -import time -from typing import Optional, Union -from urllib.parse import urlparse - -import aiohttp -import requests -from open_webui.apps.webui.models.models import Models -from open_webui.config import ( - CORS_ALLOW_ORIGIN, - ENABLE_MODEL_FILTER, - ENABLE_OLLAMA_API, - MODEL_FILTER_LIST, - OLLAMA_BASE_URLS, - UPLOAD_DIR, - AppConfig, -) -from open_webui.env import AIOHTTP_CLIENT_TIMEOUT - - -from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ENV, SRC_LOG_LEVELS -from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse -from pydantic import BaseModel, ConfigDict -from starlette.background import BackgroundTask - - -from open_webui.utils.misc import ( - calculate_sha256, -) -from open_webui.utils.payload import ( - apply_model_params_to_body_ollama, - apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, -) -from open_webui.utils.utils import get_admin_user, get_verified_user - -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) - - -app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None) - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.config = AppConfig() - -app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - -app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API -app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS -app.state.MODELS = {} - - -# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. -# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, -# least connections, or least response time for better resource utilization and performance optimization. 
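The TODO above spells out alternatives to the current random.choice round-robin over OLLAMA_BASE_URLS. One possible shape for the "least connections" variant it mentions is sketched below; the class and its names are illustrative only and do not exist in the removed module.

```python
import random
from collections import defaultdict


class LeastConnectionsBalancer:
    """Pick the upstream Ollama URL with the fewest in-flight requests."""

    def __init__(self, urls: list[str]):
        if not urls:
            raise ValueError("at least one Ollama base URL is required")
        self.urls = urls
        self.active = defaultdict(int)  # url -> number of in-flight requests

    def acquire(self) -> str:
        # Choose the least-loaded URL; ties are broken randomly, which degrades to the
        # existing random.choice behaviour when all counts are equal.
        least = min(self.active[url] for url in self.urls)
        candidates = [url for url in self.urls if self.active[url] == least]
        url = random.choice(candidates)
        self.active[url] += 1
        return url

    def release(self, url: str) -> None:
        self.active[url] = max(0, self.active[url] - 1)
```

A caller would wrap each upstream request in acquire()/release(), so slower instances accumulate fewer new requests than faster ones.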
- - -@app.middleware("http") -async def check_url(request: Request, call_next): - if len(app.state.MODELS) == 0: - await get_all_models() - else: - pass - - response = await call_next(request) - return response - - -@app.head("/") -@app.get("/") -async def get_status(): - return {"status": True} - - -@app.get("/config") -async def get_config(user=Depends(get_admin_user)): - return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API} - - -class OllamaConfigForm(BaseModel): - enable_ollama_api: Optional[bool] = None - - -@app.post("/config/update") -async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)): - app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api - return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API} - - -@app.get("/urls") -async def get_ollama_api_urls(user=Depends(get_admin_user)): - return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} - - -class UrlUpdateForm(BaseModel): - urls: list[str] - - -@app.post("/urls/update") -async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - app.state.config.OLLAMA_BASE_URLS = form_data.urls - - log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}") - return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} - - -async def fetch_url(url): - timeout = aiohttp.ClientTimeout(total=3) - try: - async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get(url) as response: - return await response.json() - except Exception as e: - # Handle connection error here - log.error(f"Connection error: {e}") - return None - - -async def cleanup_response( - response: Optional[aiohttp.ClientResponse], - session: Optional[aiohttp.ClientSession], -): - if response: - response.close() - if session: - await session.close() - - -async def post_streaming_url( - url: str, payload: Union[str, bytes], stream: bool = True, content_type=None -): - r = None - try: - session = aiohttp.ClientSession( - trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) - ) - r = await session.post( - url, - data=payload, - headers={"Content-Type": "application/json"}, - ) - r.raise_for_status() - - if stream: - headers = dict(r.headers) - if content_type: - headers["Content-Type"] = content_type - return StreamingResponse( - r.content, - status_code=r.status, - headers=headers, - background=BackgroundTask( - cleanup_response, response=r, session=session - ), - ) - else: - res = await r.json() - await cleanup_response(r, session) - return res - - except Exception as e: - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = await r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status if r else 500, - detail=error_detail, - ) - - -def merge_models_lists(model_lists): - merged_models = {} - - for idx, model_list in enumerate(model_lists): - if model_list is not None: - for model in model_list: - digest = model["digest"] - if digest not in merged_models: - model["urls"] = [idx] - merged_models[digest] = model - else: - merged_models[digest]["urls"].append(idx) - - return list(merged_models.values()) - - -async def get_all_models(): - log.info("get_all_models()") - - if app.state.config.ENABLE_OLLAMA_API: - tasks = [ - fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS - ] - responses = await asyncio.gather(*tasks) - - models 
= { - "models": merge_models_lists( - map( - lambda response: response["models"] if response else None, responses - ) - ) - } - - else: - models = {"models": []} - - app.state.MODELS = {model["model"]: model for model in models["models"]} - - return models - - -@app.get("/api/tags") -@app.get("/api/tags/{url_idx}") -async def get_ollama_tags( - url_idx: Optional[int] = None, user=Depends(get_verified_user) -): - if url_idx is None: - models = await get_all_models() - - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user": - models["models"] = list( - filter( - lambda model: model["name"] - in app.state.config.MODEL_FILTER_LIST, - models["models"], - ) - ) - return models - return models - else: - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - - r = None - try: - r = requests.request(method="GET", url=f"{url}/api/tags") - r.raise_for_status() - - return r.json() - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) - - -@app.get("/api/version") -@app.get("/api/version/{url_idx}") -async def get_ollama_versions(url_idx: Optional[int] = None): - if app.state.config.ENABLE_OLLAMA_API: - if url_idx is None: - # returns lowest version - tasks = [ - fetch_url(f"{url}/api/version") - for url in app.state.config.OLLAMA_BASE_URLS - ] - responses = await asyncio.gather(*tasks) - responses = list(filter(lambda x: x is not None, responses)) - - if len(responses) > 0: - lowest_version = min( - responses, - key=lambda x: tuple( - map(int, re.sub(r"^v|-.*", "", x["version"]).split(".")) - ), - ) - - return {"version": lowest_version["version"]} - else: - raise HTTPException( - status_code=500, - detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, - ) - else: - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - - r = None - try: - r = requests.request(method="GET", url=f"{url}/api/version") - r.raise_for_status() - - return r.json() - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) - else: - return {"version": False} - - -class ModelNameForm(BaseModel): - name: str - - -@app.post("/api/pull") -@app.post("/api/pull/{url_idx}") -async def pull_model( - form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) -): - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - # Admin should be able to pull models from any source - payload = {**form_data.model_dump(exclude_none=True), "insecure": True} - - return await post_streaming_url(f"{url}/api/pull", json.dumps(payload)) - - -class PushModelForm(BaseModel): - name: str - insecure: Optional[bool] = None - stream: Optional[bool] = None - - -@app.delete("/api/push") -@app.delete("/api/push/{url_idx}") -async def push_model( - form_data: PushModelForm, - url_idx: Optional[int] = None, - user=Depends(get_admin_user), -): - if url_idx is None: - if form_data.name in app.state.MODELS: - url_idx = app.state.MODELS[form_data.name]["urls"][0] - else: - raise HTTPException( - status_code=400, - 
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.debug(f"url: {url}") - - return await post_streaming_url( - f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode() - ) - - -class CreateModelForm(BaseModel): - name: str - modelfile: Optional[str] = None - stream: Optional[bool] = None - path: Optional[str] = None - - -@app.post("/api/create") -@app.post("/api/create/{url_idx}") -async def create_model( - form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) -): - log.debug(f"form_data: {form_data}") - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - return await post_streaming_url( - f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode() - ) - - -class CopyModelForm(BaseModel): - source: str - destination: str - - -@app.post("/api/copy") -@app.post("/api/copy/{url_idx}") -async def copy_model( - form_data: CopyModelForm, - url_idx: Optional[int] = None, - user=Depends(get_admin_user), -): - if url_idx is None: - if form_data.source in app.state.MODELS: - url_idx = app.state.MODELS[form_data.source]["urls"][0] - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - r = requests.request( - method="POST", - url=f"{url}/api/copy", - headers={"Content-Type": "application/json"}, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) - - try: - r.raise_for_status() - - log.debug(f"r.text: {r.text}") - - return True - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) - - -@app.delete("/api/delete") -@app.delete("/api/delete/{url_idx}") -async def delete_model( - form_data: ModelNameForm, - url_idx: Optional[int] = None, - user=Depends(get_admin_user), -): - if url_idx is None: - if form_data.name in app.state.MODELS: - url_idx = app.state.MODELS[form_data.name]["urls"][0] - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - r = requests.request( - method="DELETE", - url=f"{url}/api/delete", - headers={"Content-Type": "application/json"}, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) - try: - r.raise_for_status() - - log.debug(f"r.text: {r.text}") - - return True - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) - - -@app.post("/api/show") -async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)): - if form_data.name not in app.state.MODELS: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), - ) - - url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - r = requests.request( - 
method="POST", - url=f"{url}/api/show", - headers={"Content-Type": "application/json"}, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) - try: - r.raise_for_status() - - return r.json() - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) - - -class GenerateEmbeddingsForm(BaseModel): - model: str - prompt: str - options: Optional[dict] = None - keep_alive: Optional[Union[int, str]] = None - - -class GenerateEmbedForm(BaseModel): - model: str - input: list[str] | str - truncate: Optional[bool] = None - options: Optional[dict] = None - keep_alive: Optional[Union[int, str]] = None - - -@app.post("/api/embed") -@app.post("/api/embed/{url_idx}") -async def generate_embeddings( - form_data: GenerateEmbedForm, - url_idx: Optional[int] = None, - user=Depends(get_verified_user), -): - return generate_ollama_batch_embeddings(form_data, url_idx) - - -@app.post("/api/embeddings") -@app.post("/api/embeddings/{url_idx}") -async def generate_embeddings( - form_data: GenerateEmbeddingsForm, - url_idx: Optional[int] = None, - user=Depends(get_verified_user), -): - return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx) - - -def generate_ollama_embeddings( - form_data: GenerateEmbeddingsForm, - url_idx: Optional[int] = None, -): - log.info(f"generate_ollama_embeddings {form_data}") - - if url_idx is None: - model = form_data.model - - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - r = requests.request( - method="POST", - url=f"{url}/api/embeddings", - headers={"Content-Type": "application/json"}, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) - try: - r.raise_for_status() - - data = r.json() - - log.info(f"generate_ollama_embeddings {data}") - - if "embedding" in data: - return data - else: - raise Exception("Something went wrong :/") - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) - - -def generate_ollama_batch_embeddings( - form_data: GenerateEmbedForm, - url_idx: Optional[int] = None, -): - log.info(f"generate_ollama_batch_embeddings {form_data}") - - if url_idx is None: - model = form_data.model - - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - r = requests.request( - method="POST", - url=f"{url}/api/embed", - headers={"Content-Type": "application/json"}, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) - try: - r.raise_for_status() - - data = r.json() - - 
log.info(f"generate_ollama_batch_embeddings {data}") - - if "embeddings" in data: - return data - else: - raise Exception("Something went wrong :/") - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise Exception(error_detail) - - -class GenerateCompletionForm(BaseModel): - model: str - prompt: str - images: Optional[list[str]] = None - format: Optional[str] = None - options: Optional[dict] = None - system: Optional[str] = None - template: Optional[str] = None - context: Optional[str] = None - stream: Optional[bool] = True - raw: Optional[bool] = None - keep_alive: Optional[Union[int, str]] = None - - -@app.post("/api/generate") -@app.post("/api/generate/{url_idx}") -async def generate_completion( - form_data: GenerateCompletionForm, - url_idx: Optional[int] = None, - user=Depends(get_verified_user), -): - if url_idx is None: - model = form_data.model - - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - return await post_streaming_url( - f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode() - ) - - -class ChatMessage(BaseModel): - role: str - content: str - images: Optional[list[str]] = None - - -class GenerateChatCompletionForm(BaseModel): - model: str - messages: list[ChatMessage] - format: Optional[str] = None - options: Optional[dict] = None - template: Optional[str] = None - stream: Optional[bool] = True - keep_alive: Optional[Union[int, str]] = None - - -def get_ollama_url(url_idx: Optional[int], model: str): - if url_idx is None: - if model not in app.state.MODELS: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), - ) - url_idx = random.choice(app.state.MODELS[model]["urls"]) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - return url - - -@app.post("/api/chat") -@app.post("/api/chat/{url_idx}") -async def generate_chat_completion( - form_data: GenerateChatCompletionForm, - url_idx: Optional[int] = None, - user=Depends(get_verified_user), - bypass_filter: Optional[bool] = False, -): - payload = {**form_data.model_dump(exclude_none=True)} - log.debug(f"generate_chat_completion() - 1.payload = {payload}") - if "metadata" in payload: - del payload["metadata"] - - model_id = form_data.model - - if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: - raise HTTPException( - status_code=403, - detail="Model not found", - ) - - model_info = Models.get_model_by_id(model_id) - - if model_info: - if model_info.base_model_id: - payload["model"] = model_info.base_model_id - - params = model_info.params.model_dump() - - if params: - if payload.get("options") is None: - payload["options"] = {} - - payload["options"] = apply_model_params_to_body_ollama( - params, payload["options"] - ) - payload = apply_model_system_prompt_to_body(params, payload, user) - - if ":" not in payload["model"]: - payload["model"] = f"{payload['model']}:latest" - - url = get_ollama_url(url_idx, payload["model"]) - log.info(f"url: {url}") - 
log.debug(f"generate_chat_completion() - 2.payload = {payload}") - - return await post_streaming_url( - f"{url}/api/chat", - json.dumps(payload), - stream=form_data.stream, - content_type="application/x-ndjson", - ) - - -# TODO: we should update this part once Ollama supports other types -class OpenAIChatMessageContent(BaseModel): - type: str - model_config = ConfigDict(extra="allow") - - -class OpenAIChatMessage(BaseModel): - role: str - content: Union[str, OpenAIChatMessageContent] - - model_config = ConfigDict(extra="allow") - - -class OpenAIChatCompletionForm(BaseModel): - model: str - messages: list[OpenAIChatMessage] - - model_config = ConfigDict(extra="allow") - - -@app.post("/v1/chat/completions") -@app.post("/v1/chat/completions/{url_idx}") -async def generate_openai_chat_completion( - form_data: dict, - url_idx: Optional[int] = None, - user=Depends(get_verified_user), -): - completion_form = OpenAIChatCompletionForm(**form_data) - payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])} - if "metadata" in payload: - del payload["metadata"] - - model_id = completion_form.model - - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: - raise HTTPException( - status_code=403, - detail="Model not found", - ) - - model_info = Models.get_model_by_id(model_id) - - if model_info: - if model_info.base_model_id: - payload["model"] = model_info.base_model_id - - params = model_info.params.model_dump() - - if params: - payload = apply_model_params_to_body_openai(params, payload) - payload = apply_model_system_prompt_to_body(params, payload, user) - - if ":" not in payload["model"]: - payload["model"] = f"{payload['model']}:latest" - - url = get_ollama_url(url_idx, payload["model"]) - log.info(f"url: {url}") - - return await post_streaming_url( - f"{url}/v1/chat/completions", - json.dumps(payload), - stream=payload.get("stream", False), - ) - - -@app.get("/v1/models") -@app.get("/v1/models/{url_idx}") -async def get_openai_models( - url_idx: Optional[int] = None, - user=Depends(get_verified_user), -): - if url_idx is None: - models = await get_all_models() - - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user": - models["models"] = list( - filter( - lambda model: model["name"] - in app.state.config.MODEL_FILTER_LIST, - models["models"], - ) - ) - - return { - "data": [ - { - "id": model["model"], - "object": "model", - "created": int(time.time()), - "owned_by": "openai", - } - for model in models["models"] - ], - "object": "list", - } - - else: - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - try: - r = requests.request(method="GET", url=f"{url}/api/tags") - r.raise_for_status() - - models = r.json() - - return { - "data": [ - { - "id": model["model"], - "object": "model", - "created": int(time.time()), - "owned_by": "openai", - } - for model in models["models"] - ], - "object": "list", - } - - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) - - -class UrlForm(BaseModel): - url: str - - -class UploadBlobForm(BaseModel): - filename: str - - -def parse_huggingface_url(hf_url): - try: - # Parse the URL - parsed_url = urlparse(hf_url) - - # Get the path and split it into components - 
path_components = parsed_url.path.split("/") - - # Extract the desired output - model_file = path_components[-1] - - return model_file - except ValueError: - return None - - -async def download_file_stream( - ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024 -): - done = False - - if os.path.exists(file_path): - current_size = os.path.getsize(file_path) - else: - current_size = 0 - - headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {} - - timeout = aiohttp.ClientTimeout(total=600) # Set the timeout - - async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get(file_url, headers=headers) as response: - total_size = int(response.headers.get("content-length", 0)) + current_size - - with open(file_path, "ab+") as file: - async for data in response.content.iter_chunked(chunk_size): - current_size += len(data) - file.write(data) - - done = current_size == total_size - progress = round((current_size / total_size) * 100, 2) - - yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n' - - if done: - file.seek(0) - hashed = calculate_sha256(file) - file.seek(0) - - url = f"{ollama_url}/api/blobs/sha256:{hashed}" - response = requests.post(url, data=file) - - if response.ok: - res = { - "done": done, - "blob": f"sha256:{hashed}", - "name": file_name, - } - os.remove(file_path) - - yield f"data: {json.dumps(res)}\n\n" - else: - raise "Ollama: Could not create blob, Please try again." - - -# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" -@app.post("/models/download") -@app.post("/models/download/{url_idx}") -async def download_model( - form_data: UrlForm, - url_idx: Optional[int] = None, - user=Depends(get_admin_user), -): - allowed_hosts = ["https://huggingface.co/", "https://github.com/"] - - if not any(form_data.url.startswith(host) for host in allowed_hosts): - raise HTTPException( - status_code=400, - detail="Invalid file_url. 
Only URLs from allowed hosts are permitted.", - ) - - if url_idx is None: - url_idx = 0 - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - - file_name = parse_huggingface_url(form_data.url) - - if file_name: - file_path = f"{UPLOAD_DIR}/{file_name}" - - return StreamingResponse( - download_file_stream(url, form_data.url, file_path, file_name), - ) - else: - return None - - -@app.post("/models/upload") -@app.post("/models/upload/{url_idx}") -def upload_model( - file: UploadFile = File(...), - url_idx: Optional[int] = None, - user=Depends(get_admin_user), -): - if url_idx is None: - url_idx = 0 - ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] - - file_path = f"{UPLOAD_DIR}/{file.filename}" - - # Save file in chunks - with open(file_path, "wb+") as f: - for chunk in file.file: - f.write(chunk) - - def file_process_stream(): - nonlocal ollama_url - total_size = os.path.getsize(file_path) - chunk_size = 1024 * 1024 - try: - with open(file_path, "rb") as f: - total = 0 - done = False - - while not done: - chunk = f.read(chunk_size) - if not chunk: - done = True - continue - - total += len(chunk) - progress = round((total / total_size) * 100, 2) - - res = { - "progress": progress, - "total": total_size, - "completed": total, - } - yield f"data: {json.dumps(res)}\n\n" - - if done: - f.seek(0) - hashed = calculate_sha256(f) - f.seek(0) - - url = f"{ollama_url}/api/blobs/sha256:{hashed}" - response = requests.post(url, data=f) - - if response.ok: - res = { - "done": done, - "blob": f"sha256:{hashed}", - "name": file.filename, - } - os.remove(file_path) - yield f"data: {json.dumps(res)}\n\n" - else: - raise Exception( - "Ollama: Could not create blob, Please try again." - ) - - except Exception as e: - res = {"error": str(e)} - yield f"data: {json.dumps(res)}\n\n" - - return StreamingResponse(file_process_stream(), media_type="text/event-stream") diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py deleted file mode 100644 index f7e1c7bf8..000000000 --- a/backend/open_webui/apps/openai/main.py +++ /dev/null @@ -1,557 +0,0 @@ -import asyncio -import hashlib -import json -import logging -from pathlib import Path -from typing import Literal, Optional, overload - -import aiohttp -import requests -from open_webui.apps.webui.models.models import Models -from open_webui.config import ( - CACHE_DIR, - CORS_ALLOW_ORIGIN, - ENABLE_MODEL_FILTER, - ENABLE_OPENAI_API, - MODEL_FILTER_LIST, - OPENAI_API_BASE_URLS, - OPENAI_API_KEYS, - AppConfig, -) -from open_webui.env import ( - AIOHTTP_CLIENT_TIMEOUT, - AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, -) - -from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ENV, SRC_LOG_LEVELS -from fastapi import Depends, FastAPI, HTTPException, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse, StreamingResponse -from pydantic import BaseModel -from starlette.background import BackgroundTask - -from open_webui.utils.payload import ( - apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, -) - -from open_webui.utils.utils import get_admin_user, get_verified_user - -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["OPENAI"]) - - -app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None) - - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.config = 
AppConfig() - -app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - -app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API -app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS -app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS - -app.state.MODELS = {} - - -@app.middleware("http") -async def check_url(request: Request, call_next): - if len(app.state.MODELS) == 0: - await get_all_models() - - response = await call_next(request) - return response - - -@app.get("/config") -async def get_config(user=Depends(get_admin_user)): - return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API} - - -class OpenAIConfigForm(BaseModel): - enable_openai_api: Optional[bool] = None - - -@app.post("/config/update") -async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)): - app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api - return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API} - - -class UrlsUpdateForm(BaseModel): - urls: list[str] - - -class KeysUpdateForm(BaseModel): - keys: list[str] - - -@app.get("/urls") -async def get_openai_urls(user=Depends(get_admin_user)): - return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} - - -@app.post("/urls/update") -async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): - await get_all_models() - app.state.config.OPENAI_API_BASE_URLS = form_data.urls - return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} - - -@app.get("/keys") -async def get_openai_keys(user=Depends(get_admin_user)): - return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} - - -@app.post("/keys/update") -async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): - app.state.config.OPENAI_API_KEYS = form_data.keys - return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} - - -@app.post("/audio/speech") -async def speech(request: Request, user=Depends(get_verified_user)): - idx = None - try: - idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") - body = await request.body() - name = hashlib.sha256(body).hexdigest() - - SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") - SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) - file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3") - file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json") - - # Check if the file already exists in the cache - if file_path.is_file(): - return FileResponse(file_path) - - headers = {} - headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}" - headers["Content-Type"] = "application/json" - if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]: - headers["HTTP-Referer"] = "https://openwebui.com/" - headers["X-Title"] = "Open WebUI" - r = None - try: - r = requests.post( - url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech", - data=body, - headers=headers, - stream=True, - ) - - r.raise_for_status() - - # Save the streaming content to a file - with open(file_path, "wb") as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) - - with open(file_body_path, "w") as f: - json.dump(json.loads(body.decode("utf-8")), f) - - # Return the saved file - return FileResponse(file_path) - - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"External: {res['error']}" - except Exception: - 
error_detail = f"External: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, detail=error_detail - ) - - except ValueError: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) - - -async def fetch_url(url, key): - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) - try: - headers = {"Authorization": f"Bearer {key}"} - async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get(url, headers=headers) as response: - return await response.json() - except Exception as e: - # Handle connection error here - log.error(f"Connection error: {e}") - return None - - -async def cleanup_response( - response: Optional[aiohttp.ClientResponse], - session: Optional[aiohttp.ClientSession], -): - if response: - response.close() - if session: - await session.close() - - -def merge_models_lists(model_lists): - log.debug(f"merge_models_lists {model_lists}") - merged_list = [] - - for idx, models in enumerate(model_lists): - if models is not None and "error" not in models: - merged_list.extend( - [ - { - **model, - "name": model.get("name", model["id"]), - "owned_by": "openai", - "openai": model, - "urlIdx": idx, - } - for model in models - if "api.openai.com" - not in app.state.config.OPENAI_API_BASE_URLS[idx] - or not any( - name in model["id"] - for name in [ - "babbage", - "dall-e", - "davinci", - "embedding", - "tts", - "whisper", - ] - ) - ] - ) - - return merged_list - - -def is_openai_api_disabled(): - return not app.state.config.ENABLE_OPENAI_API - - -async def get_all_models_raw() -> list: - if is_openai_api_disabled(): - return [] - - # Check if API KEYS length is same than API URLS length - num_urls = len(app.state.config.OPENAI_API_BASE_URLS) - num_keys = len(app.state.config.OPENAI_API_KEYS) - - if num_keys != num_urls: - # if there are more keys than urls, remove the extra keys - if num_keys > num_urls: - new_keys = app.state.config.OPENAI_API_KEYS[:num_urls] - app.state.config.OPENAI_API_KEYS = new_keys - # if there are more urls than keys, add empty keys - else: - app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys) - - tasks = [ - fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) - for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS) - ] - - responses = await asyncio.gather(*tasks) - log.debug(f"get_all_models:responses() {responses}") - - return responses - - -@overload -async def get_all_models(raw: Literal[True]) -> list: ... - - -@overload -async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ... 
- - -async def get_all_models(raw=False) -> dict[str, list] | list: - log.info("get_all_models()") - if is_openai_api_disabled(): - return [] if raw else {"data": []} - - responses = await get_all_models_raw() - if raw: - return responses - - def extract_data(response): - if response and "data" in response: - return response["data"] - if isinstance(response, list): - return response - return None - - models = {"data": merge_models_lists(map(extract_data, responses))} - - log.debug(f"models: {models}") - app.state.MODELS = {model["id"]: model for model in models["data"]} - - return models - - -@app.get("/models") -@app.get("/models/{url_idx}") -async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)): - if url_idx is None: - models = await get_all_models() - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user": - models["data"] = list( - filter( - lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, - models["data"], - ) - ) - return models - return models - else: - url = app.state.config.OPENAI_API_BASE_URLS[url_idx] - key = app.state.config.OPENAI_API_KEYS[url_idx] - - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - - r = None - - try: - r = requests.request(method="GET", url=f"{url}/models", headers=headers) - r.raise_for_status() - - response_data = r.json() - - if "api.openai.com" in url: - # Filter the response data - response_data["data"] = [ - model - for model in response_data["data"] - if not any( - name in model["id"] - for name in [ - "babbage", - "dall-e", - "davinci", - "embedding", - "tts", - "whisper", - ] - ) - ] - - return response_data - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"External: {res['error']}" - except Exception: - error_detail = f"External: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) - - -@app.post("/chat/completions") -@app.post("/chat/completions/{url_idx}") -async def generate_chat_completion( - form_data: dict, - url_idx: Optional[int] = None, - user=Depends(get_verified_user), -): - idx = 0 - payload = {**form_data} - - if "metadata" in payload: - del payload["metadata"] - - model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id) - - if model_info: - if model_info.base_model_id: - payload["model"] = model_info.base_model_id - - params = model_info.params.model_dump() - payload = apply_model_params_to_body_openai(params, payload) - payload = apply_model_system_prompt_to_body(params, payload, user) - - model = app.state.MODELS[payload.get("model")] - idx = model["urlIdx"] - - if "pipeline" in model and model.get("pipeline"): - payload["user"] = { - "name": user.name, - "id": user.id, - "email": user.email, - "role": user.role, - } - - url = app.state.config.OPENAI_API_BASE_URLS[idx] - key = app.state.config.OPENAI_API_KEYS[idx] - is_o1 = payload["model"].lower().startswith("o1-") - - # Change max_completion_tokens to max_tokens (Backward compatible) - if "api.openai.com" not in url and not is_o1: - if "max_completion_tokens" in payload: - # Remove "max_completion_tokens" from the payload - payload["max_tokens"] = payload["max_completion_tokens"] - del payload["max_completion_tokens"] - else: - if is_o1 and "max_tokens" in payload: - payload["max_completion_tokens"] = payload["max_tokens"] - del payload["max_tokens"] - if 
"max_tokens" in payload and "max_completion_tokens" in payload: - del payload["max_tokens"] - - # Fix: O1 does not support the "system" parameter, Modify "system" to "user" - if is_o1 and payload["messages"][0]["role"] == "system": - payload["messages"][0]["role"] = "user" - - # Convert the modified body back to JSON - payload = json.dumps(payload) - - log.debug(payload) - - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]: - headers["HTTP-Referer"] = "https://openwebui.com/" - headers["X-Title"] = "Open WebUI" - - r = None - session = None - streaming = False - response = None - - try: - session = aiohttp.ClientSession( - trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) - ) - r = await session.request( - method="POST", - url=f"{url}/chat/completions", - data=payload, - headers=headers, - ) - - # Check if response is SSE - if "text/event-stream" in r.headers.get("Content-Type", ""): - streaming = True - return StreamingResponse( - r.content, - status_code=r.status, - headers=dict(r.headers), - background=BackgroundTask( - cleanup_response, response=r, session=session - ), - ) - else: - try: - response = await r.json() - except Exception as e: - log.error(e) - response = await r.text() - - r.raise_for_status() - return response - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if isinstance(response, dict): - if "error" in response: - error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}" - elif isinstance(response, str): - error_detail = response - - raise HTTPException(status_code=r.status if r else 500, detail=error_detail) - finally: - if not streaming and session: - if r: - r.close() - await session.close() - - -@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(path: str, request: Request, user=Depends(get_verified_user)): - idx = 0 - - body = await request.body() - - url = app.state.config.OPENAI_API_BASE_URLS[idx] - key = app.state.config.OPENAI_API_KEYS[idx] - - target_url = f"{url}/{path}" - - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - - r = None - session = None - streaming = False - - try: - session = aiohttp.ClientSession(trust_env=True) - r = await session.request( - method=request.method, - url=target_url, - data=body, - headers=headers, - ) - - r.raise_for_status() - - # Check if response is SSE - if "text/event-stream" in r.headers.get("Content-Type", ""): - streaming = True - return StreamingResponse( - r.content, - status_code=r.status, - headers=dict(r.headers), - background=BackgroundTask( - cleanup_response, response=r, session=session - ), - ) - else: - response_data = await r.json() - return response_data - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = await r.json() - print(res) - if "error" in res: - error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" - except Exception: - error_detail = f"External: {e}" - raise HTTPException(status_code=r.status if r else 500, detail=error_detail) - finally: - if not streaming and session: - if r: - r.close() - await session.close() diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py deleted file mode 
100644 index 85472485d..000000000 --- a/backend/open_webui/apps/retrieval/main.py +++ /dev/null @@ -1,1332 +0,0 @@ -# TODO: Merge this with the webui_app and make it a single app - -import json -import logging -import mimetypes -import os -import shutil - -import uuid -from datetime import datetime -from pathlib import Path -from typing import Iterator, Optional, Sequence, Union - -from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel -import tiktoken - - -from open_webui.storage.provider import Storage -from open_webui.apps.webui.models.knowledge import Knowledges -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT - -# Document loaders -from open_webui.apps.retrieval.loaders.main import Loader - -# Web search engines -from open_webui.apps.retrieval.web.main import SearchResult -from open_webui.apps.retrieval.web.utils import get_web_loader -from open_webui.apps.retrieval.web.brave import search_brave -from open_webui.apps.retrieval.web.duckduckgo import search_duckduckgo -from open_webui.apps.retrieval.web.google_pse import search_google_pse -from open_webui.apps.retrieval.web.jina_search import search_jina -from open_webui.apps.retrieval.web.searchapi import search_searchapi -from open_webui.apps.retrieval.web.searxng import search_searxng -from open_webui.apps.retrieval.web.serper import search_serper -from open_webui.apps.retrieval.web.serply import search_serply -from open_webui.apps.retrieval.web.serpstack import search_serpstack -from open_webui.apps.retrieval.web.tavily import search_tavily - - -from open_webui.apps.retrieval.utils import ( - get_embedding_function, - get_model_path, - query_collection, - query_collection_with_hybrid_search, - query_doc, - query_doc_with_hybrid_search, -) - -from open_webui.apps.webui.models.files import Files -from open_webui.config import ( - BRAVE_SEARCH_API_KEY, - TIKTOKEN_ENCODING_NAME, - RAG_TEXT_SPLITTER, - CHUNK_OVERLAP, - CHUNK_SIZE, - CONTENT_EXTRACTION_ENGINE, - CORS_ALLOW_ORIGIN, - ENABLE_RAG_HYBRID_SEARCH, - ENABLE_RAG_LOCAL_WEB_FETCH, - ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - ENABLE_RAG_WEB_SEARCH, - ENV, - GOOGLE_PSE_API_KEY, - GOOGLE_PSE_ENGINE_ID, - PDF_EXTRACT_IMAGES, - RAG_EMBEDDING_ENGINE, - RAG_EMBEDDING_MODEL, - RAG_EMBEDDING_MODEL_AUTO_UPDATE, - RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, - RAG_EMBEDDING_BATCH_SIZE, - RAG_FILE_MAX_COUNT, - RAG_FILE_MAX_SIZE, - RAG_OPENAI_API_BASE_URL, - RAG_OPENAI_API_KEY, - RAG_RELEVANCE_THRESHOLD, - RAG_RERANKING_MODEL, - RAG_RERANKING_MODEL_AUTO_UPDATE, - RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, - DEFAULT_RAG_TEMPLATE, - RAG_TEMPLATE, - RAG_TOP_K, - RAG_WEB_SEARCH_CONCURRENT_REQUESTS, - RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - RAG_WEB_SEARCH_ENGINE, - RAG_WEB_SEARCH_RESULT_COUNT, - SEARCHAPI_API_KEY, - SEARCHAPI_ENGINE, - SEARXNG_QUERY_URL, - SERPER_API_KEY, - SERPLY_API_KEY, - SERPSTACK_API_KEY, - SERPSTACK_HTTPS, - TAVILY_API_KEY, - TIKA_SERVER_URL, - UPLOAD_DIR, - YOUTUBE_LOADER_LANGUAGE, - AppConfig, -) -from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER -from open_webui.utils.misc import ( - calculate_sha256, - calculate_sha256_string, - extract_folders_after_data_docs, - sanitize_filename, -) -from open_webui.utils.utils import get_admin_user, get_verified_user - -from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter -from langchain_community.document_loaders import ( 
- YoutubeLoader, -) -from langchain_core.documents import Document - - -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["RAG"]) - -app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None) - -app.state.config = AppConfig() - -app.state.config.TOP_K = RAG_TOP_K -app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD -app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE -app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT - -app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH -app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( - ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION -) - -app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE -app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL - -app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER -app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME - -app.state.config.CHUNK_SIZE = CHUNK_SIZE -app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP - -app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE -app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL -app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE -app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL -app.state.config.RAG_TEMPLATE = RAG_TEMPLATE - -app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL -app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY - -app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES - -app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE -app.state.YOUTUBE_LOADER_TRANSLATION = None - - -app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH -app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE -app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST - -app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL -app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY -app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID -app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY -app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY -app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS -app.state.config.SERPER_API_KEY = SERPER_API_KEY -app.state.config.SERPLY_API_KEY = SERPLY_API_KEY -app.state.config.TAVILY_API_KEY = TAVILY_API_KEY -app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY -app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE -app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT -app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS - - -def update_embedding_model( - embedding_model: str, - auto_update: bool = False, -): - if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": - from sentence_transformers import SentenceTransformer - - app.state.sentence_transformer_ef = SentenceTransformer( - get_model_path(embedding_model, auto_update), - device=DEVICE_TYPE, - trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, - ) - else: - app.state.sentence_transformer_ef = None - - -def update_reranking_model( - reranking_model: str, - auto_update: bool = False, -): - if reranking_model: - if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]): - try: - from open_webui.apps.retrieval.models.colbert import ColBERT - - app.state.sentence_transformer_rf = ColBERT( - get_model_path(reranking_model, auto_update), - env="docker" if DOCKER else None, - ) - except Exception as e: - log.error(f"ColBERT: {e}") - 
app.state.sentence_transformer_rf = None - app.state.config.ENABLE_RAG_HYBRID_SEARCH = False - else: - import sentence_transformers - - try: - app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder( - get_model_path(reranking_model, auto_update), - device=DEVICE_TYPE, - trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, - ) - except: - log.error("CrossEncoder error") - app.state.sentence_transformer_rf = None - app.state.config.ENABLE_RAG_HYBRID_SEARCH = False - else: - app.state.sentence_transformer_rf = None - - -update_embedding_model( - app.state.config.RAG_EMBEDDING_MODEL, - RAG_EMBEDDING_MODEL_AUTO_UPDATE, -) - -update_reranking_model( - app.state.config.RAG_RERANKING_MODEL, - RAG_RERANKING_MODEL_AUTO_UPDATE, -) - - -app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_BATCH_SIZE, -) - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -class CollectionNameForm(BaseModel): - collection_name: Optional[str] = None - - -class ProcessUrlForm(CollectionNameForm): - url: str - - -class SearchForm(CollectionNameForm): - query: str - - -@app.get("/") -async def get_status(): - return { - "status": True, - "chunk_size": app.state.config.CHUNK_SIZE, - "chunk_overlap": app.state.config.CHUNK_OVERLAP, - "template": app.state.config.RAG_TEMPLATE, - "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, - "reranking_model": app.state.config.RAG_RERANKING_MODEL, - "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, - } - - -@app.get("/embedding") -async def get_embedding_config(user=Depends(get_admin_user)): - return { - "status": True, - "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, - "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, - "openai_config": { - "url": app.state.config.OPENAI_API_BASE_URL, - "key": app.state.config.OPENAI_API_KEY, - }, - } - - -@app.get("/reranking") -async def get_reraanking_config(user=Depends(get_admin_user)): - return { - "status": True, - "reranking_model": app.state.config.RAG_RERANKING_MODEL, - } - - -class OpenAIConfigForm(BaseModel): - url: str - key: str - - -class EmbeddingModelUpdateForm(BaseModel): - openai_config: Optional[OpenAIConfigForm] = None - embedding_engine: str - embedding_model: str - embedding_batch_size: Optional[int] = 1 - - -@app.post("/embedding/update") -async def update_embedding_config( - form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) -): - log.info( - f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" - ) - try: - app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine - app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model - - if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: - if form_data.openai_config is not None: - app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url - app.state.config.OPENAI_API_KEY = form_data.openai_config.key - app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size - - update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL) - - app.state.EMBEDDING_FUNCTION = 
get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_BATCH_SIZE, - ) - - return { - "status": True, - "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, - "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, - "openai_config": { - "url": app.state.config.OPENAI_API_BASE_URL, - "key": app.state.config.OPENAI_API_KEY, - }, - } - except Exception as e: - log.exception(f"Problem updating embedding model: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - -class RerankingModelUpdateForm(BaseModel): - reranking_model: str - - -@app.post("/reranking/update") -async def update_reranking_config( - form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) -): - log.info( - f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}" - ) - try: - app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model - - update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True) - - return { - "status": True, - "reranking_model": app.state.config.RAG_RERANKING_MODEL, - } - except Exception as e: - log.exception(f"Problem updating reranking model: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - -@app.get("/config") -async def get_rag_config(user=Depends(get_admin_user)): - return { - "status": True, - "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, - "content_extraction": { - "engine": app.state.config.CONTENT_EXTRACTION_ENGINE, - "tika_server_url": app.state.config.TIKA_SERVER_URL, - }, - "chunk": { - "text_splitter": app.state.config.TEXT_SPLITTER, - "chunk_size": app.state.config.CHUNK_SIZE, - "chunk_overlap": app.state.config.CHUNK_OVERLAP, - }, - "file": { - "max_size": app.state.config.FILE_MAX_SIZE, - "max_count": app.state.config.FILE_MAX_COUNT, - }, - "youtube": { - "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, - "translation": app.state.YOUTUBE_LOADER_TRANSLATION, - }, - "web": { - "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - "search": { - "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH, - "engine": app.state.config.RAG_WEB_SEARCH_ENGINE, - "searxng_query_url": app.state.config.SEARXNG_QUERY_URL, - "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY, - "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID, - "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY, - "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, - "serpstack_https": app.state.config.SERPSTACK_HTTPS, - "serper_api_key": app.state.config.SERPER_API_KEY, - "serply_api_key": app.state.config.SERPLY_API_KEY, - "tavily_api_key": app.state.config.TAVILY_API_KEY, - "searchapi_api_key": app.state.config.SEARCHAPI_API_KEY, - "seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE, - "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, - }, - }, - } - - -class FileConfig(BaseModel): - max_size: Optional[int] = None - max_count: Optional[int] = None - - -class ContentExtractionConfig(BaseModel): - engine: str = "" - tika_server_url: Optional[str] = None - - -class ChunkParamUpdateForm(BaseModel): - 
text_splitter: Optional[str] = None - chunk_size: int - chunk_overlap: int - - -class YoutubeLoaderConfig(BaseModel): - language: list[str] - translation: Optional[str] = None - - -class WebSearchConfig(BaseModel): - enabled: bool - engine: Optional[str] = None - searxng_query_url: Optional[str] = None - google_pse_api_key: Optional[str] = None - google_pse_engine_id: Optional[str] = None - brave_search_api_key: Optional[str] = None - serpstack_api_key: Optional[str] = None - serpstack_https: Optional[bool] = None - serper_api_key: Optional[str] = None - serply_api_key: Optional[str] = None - tavily_api_key: Optional[str] = None - searchapi_api_key: Optional[str] = None - searchapi_engine: Optional[str] = None - result_count: Optional[int] = None - concurrent_requests: Optional[int] = None - - -class WebConfig(BaseModel): - search: WebSearchConfig - web_loader_ssl_verification: Optional[bool] = None - - -class ConfigUpdateForm(BaseModel): - pdf_extract_images: Optional[bool] = None - file: Optional[FileConfig] = None - content_extraction: Optional[ContentExtractionConfig] = None - chunk: Optional[ChunkParamUpdateForm] = None - youtube: Optional[YoutubeLoaderConfig] = None - web: Optional[WebConfig] = None - - -@app.post("/config/update") -async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - app.state.config.PDF_EXTRACT_IMAGES = ( - form_data.pdf_extract_images - if form_data.pdf_extract_images is not None - else app.state.config.PDF_EXTRACT_IMAGES - ) - - if form_data.file is not None: - app.state.config.FILE_MAX_SIZE = form_data.file.max_size - app.state.config.FILE_MAX_COUNT = form_data.file.max_count - - if form_data.content_extraction is not None: - log.info(f"Updating text settings: {form_data.content_extraction}") - app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine - app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url - - if form_data.chunk is not None: - app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter - app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size - app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap - - if form_data.youtube is not None: - app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language - app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation - - if form_data.web is not None: - app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( - form_data.web.web_loader_ssl_verification - ) - - app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled - app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine - app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url - app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key - app.state.config.GOOGLE_PSE_ENGINE_ID = ( - form_data.web.search.google_pse_engine_id - ) - app.state.config.BRAVE_SEARCH_API_KEY = ( - form_data.web.search.brave_search_api_key - ) - app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key - app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https - app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key - app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key - app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key - app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key - app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine - 
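The nested pydantic forms above translate into a single JSON document that an admin client would POST to the retrieval app's `/config/update` route. A hedged example payload, assuming a `/retrieval` mount prefix and a bearer admin token (both assumptions for illustration):

```python
import requests

payload = {
    "pdf_extract_images": True,
    "file": {"max_size": 10, "max_count": 5},
    "chunk": {"text_splitter": "token", "chunk_size": 1000, "chunk_overlap": 100},
    "youtube": {"language": ["en"], "translation": None},
    "web": {
        "web_loader_ssl_verification": True,
        "search": {
            "enabled": True,
            "engine": "searxng",
            "searxng_query_url": "http://searxng:8080/search?q=<query>",
            "result_count": 3,
            "concurrent_requests": 10,
        },
    },
}

r = requests.post(
    "http://localhost:8080/retrieval/config/update",
    headers={"Authorization": "Bearer <admin-token>"},
    json=payload,
)
print(r.json()["chunk"])
```

Omitted sub-sections (for example `content_extraction`) are left untouched, because the handler only applies a block when it is present in the submitted form.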
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count - app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( - form_data.web.search.concurrent_requests - ) - - return { - "status": True, - "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, - "file": { - "max_size": app.state.config.FILE_MAX_SIZE, - "max_count": app.state.config.FILE_MAX_COUNT, - }, - "content_extraction": { - "engine": app.state.config.CONTENT_EXTRACTION_ENGINE, - "tika_server_url": app.state.config.TIKA_SERVER_URL, - }, - "chunk": { - "text_splitter": app.state.config.TEXT_SPLITTER, - "chunk_size": app.state.config.CHUNK_SIZE, - "chunk_overlap": app.state.config.CHUNK_OVERLAP, - }, - "youtube": { - "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, - "translation": app.state.YOUTUBE_LOADER_TRANSLATION, - }, - "web": { - "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - "search": { - "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH, - "engine": app.state.config.RAG_WEB_SEARCH_ENGINE, - "searxng_query_url": app.state.config.SEARXNG_QUERY_URL, - "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY, - "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID, - "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY, - "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, - "serpstack_https": app.state.config.SERPSTACK_HTTPS, - "serper_api_key": app.state.config.SERPER_API_KEY, - "serply_api_key": app.state.config.SERPLY_API_KEY, - "serachapi_api_key": app.state.config.SEARCHAPI_API_KEY, - "searchapi_engine": app.state.config.SEARCHAPI_ENGINE, - "tavily_api_key": app.state.config.TAVILY_API_KEY, - "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, - }, - }, - } - - -@app.get("/template") -async def get_rag_template(user=Depends(get_verified_user)): - return { - "status": True, - "template": app.state.config.RAG_TEMPLATE, - } - - -@app.get("/query/settings") -async def get_query_settings(user=Depends(get_admin_user)): - return { - "status": True, - "template": app.state.config.RAG_TEMPLATE, - "k": app.state.config.TOP_K, - "r": app.state.config.RELEVANCE_THRESHOLD, - "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, - } - - -class QuerySettingsForm(BaseModel): - k: Optional[int] = None - r: Optional[float] = None - template: Optional[str] = None - hybrid: Optional[bool] = None - - -@app.post("/query/settings/update") -async def update_query_settings( - form_data: QuerySettingsForm, user=Depends(get_admin_user) -): - app.state.config.RAG_TEMPLATE = form_data.template - app.state.config.TOP_K = form_data.k if form_data.k else 4 - app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 - - app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( - form_data.hybrid if form_data.hybrid else False - ) - - return { - "status": True, - "template": app.state.config.RAG_TEMPLATE, - "k": app.state.config.TOP_K, - "r": app.state.config.RELEVANCE_THRESHOLD, - "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, - } - - -#################################### -# -# Document process and retrieval -# -#################################### - - -def save_docs_to_vector_db( - docs, - collection_name, - metadata: Optional[dict] = None, - overwrite: bool = False, - split: bool = True, - add: bool = False, -) -> bool: - log.info(f"save_docs_to_vector_db {docs} {collection_name}") - - # Check if entries with the same hash (metadata.hash) already exist - if metadata 
and "hash" in metadata: - result = VECTOR_DB_CLIENT.query( - collection_name=collection_name, - filter={"hash": metadata["hash"]}, - ) - - if result is not None: - existing_doc_ids = result.ids[0] - if existing_doc_ids: - log.info(f"Document with hash {metadata['hash']} already exists") - raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) - - if split: - if app.state.config.TEXT_SPLITTER in ["", "character"]: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, - add_start_index=True, - ) - elif app.state.config.TEXT_SPLITTER == "token": - log.info( - f"Using token text splitter: {app.state.config.TIKTOKEN_ENCODING_NAME}" - ) - - tiktoken.get_encoding(str(app.state.config.TIKTOKEN_ENCODING_NAME)) - text_splitter = TokenTextSplitter( - encoding_name=str(app.state.config.TIKTOKEN_ENCODING_NAME), - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, - add_start_index=True, - ) - else: - raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter")) - - docs = text_splitter.split_documents(docs) - - if len(docs) == 0: - raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) - - texts = [doc.page_content for doc in docs] - metadatas = [ - { - **doc.metadata, - **(metadata if metadata else {}), - "embedding_config": json.dumps( - { - "engine": app.state.config.RAG_EMBEDDING_ENGINE, - "model": app.state.config.RAG_EMBEDDING_MODEL, - } - ), - } - for doc in docs - ] - - # ChromaDB does not like datetime formats - # for meta-data so convert them to string. - for metadata in metadatas: - for key, value in metadata.items(): - if isinstance(value, datetime): - metadata[key] = str(value) - - try: - if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): - log.info(f"collection {collection_name} already exists") - - if overwrite: - VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) - log.info(f"deleting existing collection {collection_name}") - elif add is False: - log.info( - f"collection {collection_name} already exists, overwrite is False and add is False" - ) - return True - - log.info(f"adding to collection {collection_name}") - embedding_function = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_BATCH_SIZE, - ) - - embeddings = embedding_function( - list(map(lambda x: x.replace("\n", " "), texts)) - ) - - items = [ - { - "id": str(uuid.uuid4()), - "text": text, - "vector": embeddings[idx], - "metadata": metadatas[idx], - } - for idx, text in enumerate(texts) - ] - - VECTOR_DB_CLIENT.insert( - collection_name=collection_name, - items=items, - ) - - return True - except Exception as e: - log.exception(e) - return False - - -class ProcessFileForm(BaseModel): - file_id: str - content: Optional[str] = None - collection_name: Optional[str] = None - - -@app.post("/process/file") -def process_file( - form_data: ProcessFileForm, - user=Depends(get_verified_user), -): - try: - file = Files.get_file_by_id(form_data.file_id) - - collection_name = form_data.collection_name - - if collection_name is None: - collection_name = f"file-{file.id}" - - if form_data.content: - # Update the content in the file - # Usage: /files/{file_id}/data/content/update - - VECTOR_DB_CLIENT.delete( - collection_name=f"file-{file.id}", - filter={"file_id": file.id}, - ) - - docs = [ - 
Document( - page_content=form_data.content, - metadata={ - "name": file.meta.get("name", file.filename), - "created_by": file.user_id, - "file_id": file.id, - **file.meta, - }, - ) - ] - - text_content = form_data.content - elif form_data.collection_name: - # Check if the file has already been processed and save the content - # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update - - result = VECTOR_DB_CLIENT.query( - collection_name=f"file-{file.id}", filter={"file_id": file.id} - ) - - if result is not None and len(result.ids[0]) > 0: - docs = [ - Document( - page_content=result.documents[0][idx], - metadata=result.metadatas[0][idx], - ) - for idx, id in enumerate(result.ids[0]) - ] - else: - docs = [ - Document( - page_content=file.data.get("content", ""), - metadata={ - "name": file.meta.get("name", file.filename), - "created_by": file.user_id, - "file_id": file.id, - **file.meta, - }, - ) - ] - - text_content = file.data.get("content", "") - else: - # Process the file and save the content - # Usage: /files/ - file_path = file.path - if file_path: - file_path = Storage.get_file(file_path) - loader = Loader( - engine=app.state.config.CONTENT_EXTRACTION_ENGINE, - TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, - PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, - ) - docs = loader.load( - file.filename, file.meta.get("content_type"), file_path - ) - else: - docs = [ - Document( - page_content=file.data.get("content", ""), - metadata={ - "name": file.filename, - "created_by": file.user_id, - "file_id": file.id, - **file.meta, - }, - ) - ] - text_content = " ".join([doc.page_content for doc in docs]) - - log.debug(f"text_content: {text_content}") - Files.update_file_data_by_id( - file.id, - {"content": text_content}, - ) - - hash = calculate_sha256_string(text_content) - Files.update_file_hash_by_id(file.id, hash) - - try: - result = save_docs_to_vector_db( - docs=docs, - collection_name=collection_name, - metadata={ - "file_id": file.id, - "name": file.meta.get("name", file.filename), - "hash": hash, - }, - add=(True if form_data.collection_name else False), - ) - - if result: - Files.update_file_metadata_by_id( - file.id, - { - "collection_name": collection_name, - }, - ) - - return { - "status": True, - "collection_name": collection_name, - "filename": file.meta.get("name", file.filename), - "content": text_content, - } - except Exception as e: - raise e - except Exception as e: - log.exception(e) - if "No pandoc was found" in str(e): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, - ) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e), - ) - - -class ProcessTextForm(BaseModel): - name: str - content: str - collection_name: Optional[str] = None - - -@app.post("/process/text") -def process_text( - form_data: ProcessTextForm, - user=Depends(get_verified_user), -): - collection_name = form_data.collection_name - if collection_name is None: - collection_name = calculate_sha256_string(form_data.content) - - docs = [ - Document( - page_content=form_data.content, - metadata={"name": form_data.name, "created_by": user.id}, - ) - ] - text_content = form_data.content - log.debug(f"text_content: {text_content}") - - result = save_docs_to_vector_db(docs, collection_name) - - if result: - return { - "status": True, - "collection_name": collection_name, - "content": text_content, - } - else: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - 
detail=ERROR_MESSAGES.DEFAULT(), - ) - - -@app.post("/process/youtube") -def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)): - try: - collection_name = form_data.collection_name - if not collection_name: - collection_name = calculate_sha256_string(form_data.url)[:63] - - loader = YoutubeLoader.from_youtube_url( - form_data.url, - add_video_info=True, - language=app.state.config.YOUTUBE_LOADER_LANGUAGE, - translation=app.state.YOUTUBE_LOADER_TRANSLATION, - ) - docs = loader.load() - content = " ".join([doc.page_content for doc in docs]) - log.debug(f"text_content: {content}") - save_docs_to_vector_db(docs, collection_name, overwrite=True) - - return { - "status": True, - "collection_name": collection_name, - "filename": form_data.url, - "file": { - "data": { - "content": content, - }, - "meta": { - "name": form_data.url, - }, - }, - } - except Exception as e: - log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - -@app.post("/process/web") -def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)): - try: - collection_name = form_data.collection_name - if not collection_name: - collection_name = calculate_sha256_string(form_data.url)[:63] - - loader = get_web_loader( - form_data.url, - verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, - ) - docs = loader.load() - content = " ".join([doc.page_content for doc in docs]) - log.debug(f"text_content: {content}") - save_docs_to_vector_db(docs, collection_name, overwrite=True) - - return { - "status": True, - "collection_name": collection_name, - "filename": form_data.url, - "file": { - "data": { - "content": content, - }, - "meta": { - "name": form_data.url, - }, - }, - } - except Exception as e: - log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - -def search_web(engine: str, query: str) -> list[SearchResult]: - """Search the web using a search engine and return the results as a list of SearchResult objects. 
- Will look for a search engine API key in environment variables in the following order: - - SEARXNG_QUERY_URL - - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID - - BRAVE_SEARCH_API_KEY - - SERPSTACK_API_KEY - - SERPER_API_KEY - - SERPLY_API_KEY - - TAVILY_API_KEY - - SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`) - Args: - query (str): The query to search for - """ - - # TODO: add playwright to search the web - if engine == "searxng": - if app.state.config.SEARXNG_QUERY_URL: - return search_searxng( - app.state.config.SEARXNG_QUERY_URL, - query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - ) - else: - raise Exception("No SEARXNG_QUERY_URL found in environment variables") - elif engine == "google_pse": - if ( - app.state.config.GOOGLE_PSE_API_KEY - and app.state.config.GOOGLE_PSE_ENGINE_ID - ): - return search_google_pse( - app.state.config.GOOGLE_PSE_API_KEY, - app.state.config.GOOGLE_PSE_ENGINE_ID, - query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - ) - else: - raise Exception( - "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables" - ) - elif engine == "brave": - if app.state.config.BRAVE_SEARCH_API_KEY: - return search_brave( - app.state.config.BRAVE_SEARCH_API_KEY, - query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - ) - else: - raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables") - elif engine == "serpstack": - if app.state.config.SERPSTACK_API_KEY: - return search_serpstack( - app.state.config.SERPSTACK_API_KEY, - query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - https_enabled=app.state.config.SERPSTACK_HTTPS, - ) - else: - raise Exception("No SERPSTACK_API_KEY found in environment variables") - elif engine == "serper": - if app.state.config.SERPER_API_KEY: - return search_serper( - app.state.config.SERPER_API_KEY, - query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - ) - else: - raise Exception("No SERPER_API_KEY found in environment variables") - elif engine == "serply": - if app.state.config.SERPLY_API_KEY: - return search_serply( - app.state.config.SERPLY_API_KEY, - query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - ) - else: - raise Exception("No SERPLY_API_KEY found in environment variables") - elif engine == "duckduckgo": - return search_duckduckgo( - query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - ) - elif engine == "tavily": - if app.state.config.TAVILY_API_KEY: - return search_tavily( - app.state.config.TAVILY_API_KEY, - query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - ) - else: - raise Exception("No TAVILY_API_KEY found in environment variables") - elif engine == "searchapi": - if app.state.config.SEARCHAPI_API_KEY: - return search_searchapi( - app.state.config.SEARCHAPI_API_KEY, - app.state.config.SEARCHAPI_ENGINE, - query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - ) - else: - raise Exception("No SEARCHAPI_API_KEY found in environment variables") - elif engine == "jina": - return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT) - else: - raise Exception("No search engine API key found in environment variables") - - 
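The deleted `search_web` helper above follows one pattern throughout: look up the selected engine's credentials in the persisted config and fail fast with a descriptive exception when they are missing. Below is a minimal, self-contained sketch of that credential-check pattern; `REQUIRED_KEYS` and `check_engine_ready` are illustrative names, and the sketch reads plain environment variables rather than the app's `PersistentConfig` values.

```python
from __future__ import annotations

import os

# Illustrative engine -> required-credentials table, following the priority list
# in the search_web docstring above. Engines such as duckduckgo and jina need no key.
REQUIRED_KEYS: dict[str, list[str]] = {
    "searxng": ["SEARXNG_QUERY_URL"],
    "google_pse": ["GOOGLE_PSE_API_KEY", "GOOGLE_PSE_ENGINE_ID"],
    "brave": ["BRAVE_SEARCH_API_KEY"],
    "serpstack": ["SERPSTACK_API_KEY"],
    "serper": ["SERPER_API_KEY"],
    "serply": ["SERPLY_API_KEY"],
    "tavily": ["TAVILY_API_KEY"],
    "searchapi": ["SEARCHAPI_API_KEY"],
}


def check_engine_ready(engine: str) -> None:
    """Raise a descriptive error when the selected engine's credentials are absent."""
    missing = [key for key in REQUIRED_KEYS.get(engine, []) if not os.environ.get(key)]
    if missing:
        raise RuntimeError(f"No {' or '.join(missing)} found in environment variables")


if __name__ == "__main__":
    check_engine_ready("duckduckgo")   # passes: no credentials required
    try:
        check_engine_ready("brave")    # raises unless BRAVE_SEARCH_API_KEY is set
    except RuntimeError as exc:
        print(exc)
```

Failing at dispatch time, before any request is made, keeps the error message tied to the exact credential that is missing rather than surfacing a generic HTTP failure later.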
-@app.post("/process/web/search") -def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)): - try: - logging.info( - f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}" - ) - web_results = search_web( - app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query - ) - except Exception as e: - log.exception(e) - - print(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e), - ) - - try: - collection_name = form_data.collection_name - if collection_name == "": - collection_name = calculate_sha256_string(form_data.query)[:63] - - urls = [result.link for result in web_results] - - loader = get_web_loader(urls) - docs = loader.load() - - save_docs_to_vector_db(docs, collection_name, overwrite=True) - - return { - "status": True, - "collection_name": collection_name, - "filenames": urls, - } - except Exception as e: - log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - -class QueryDocForm(BaseModel): - collection_name: str - query: str - k: Optional[int] = None - r: Optional[float] = None - hybrid: Optional[bool] = None - - -@app.post("/query/doc") -def query_doc_handler( - form_data: QueryDocForm, - user=Depends(get_verified_user), -): - try: - if app.state.config.ENABLE_RAG_HYBRID_SEARCH: - return query_doc_with_hybrid_search( - collection_name=form_data.collection_name, - query=form_data.query, - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, - reranking_function=app.state.sentence_transformer_rf, - r=( - form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD - ), - ) - else: - return query_doc( - collection_name=form_data.collection_name, - query=form_data.query, - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, - ) - except Exception as e: - log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - -class QueryCollectionsForm(BaseModel): - collection_names: list[str] - query: str - k: Optional[int] = None - r: Optional[float] = None - hybrid: Optional[bool] = None - - -@app.post("/query/collection") -def query_collection_handler( - form_data: QueryCollectionsForm, - user=Depends(get_verified_user), -): - try: - if app.state.config.ENABLE_RAG_HYBRID_SEARCH: - return query_collection_with_hybrid_search( - collection_names=form_data.collection_names, - query=form_data.query, - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, - reranking_function=app.state.sentence_transformer_rf, - r=( - form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD - ), - ) - else: - return query_collection( - collection_names=form_data.collection_names, - query=form_data.query, - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, - ) - - except Exception as e: - log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - -#################################### -# -# Vector DB operations -# -#################################### - - -class DeleteForm(BaseModel): - collection_name: str - file_id: str - - -@app.post("/delete") -def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)): - try: - if 
VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name): - file = Files.get_file_by_id(form_data.file_id) - hash = file.hash - - VECTOR_DB_CLIENT.delete( - collection_name=form_data.collection_name, - metadata={"hash": hash}, - ) - return {"status": True} - else: - return {"status": False} - except Exception as e: - log.exception(e) - return {"status": False} - - -@app.post("/reset/db") -def reset_vector_db(user=Depends(get_admin_user)): - VECTOR_DB_CLIENT.reset() - Knowledges.delete_all_knowledge() - - -@app.post("/reset/uploads") -def reset_upload_dir(user=Depends(get_admin_user)) -> bool: - folder = f"{UPLOAD_DIR}" - try: - # Check if the directory exists - if os.path.exists(folder): - # Iterate over all the files and directories in the specified directory - for filename in os.listdir(folder): - file_path = os.path.join(folder, filename) - try: - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) # Remove the file or link - elif os.path.isdir(file_path): - shutil.rmtree(file_path) # Remove the directory - except Exception as e: - print(f"Failed to delete {file_path}. Reason: {e}") - else: - print(f"The directory {folder} does not exist") - except Exception as e: - print(f"Failed to process the directory {folder}. Reason: {e}") - return True - - -if ENV == "dev": - - @app.get("/ef") - async def get_embeddings(): - return {"result": app.state.EMBEDDING_FUNCTION("hello world")} - - @app.get("/ef/{text}") - async def get_embeddings_text(text: str): - return {"result": app.state.EMBEDDING_FUNCTION(text)} diff --git a/backend/open_webui/apps/retrieval/vector/connector.py b/backend/open_webui/apps/retrieval/vector/connector.py deleted file mode 100644 index c7f00f5fd..000000000 --- a/backend/open_webui/apps/retrieval/vector/connector.py +++ /dev/null @@ -1,14 +0,0 @@ -from open_webui.config import VECTOR_DB - -if VECTOR_DB == "milvus": - from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient - - VECTOR_DB_CLIENT = MilvusClient() -elif VECTOR_DB == "qdrant": - from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient - - VECTOR_DB_CLIENT = QdrantClient() -else: - from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient - - VECTOR_DB_CLIENT = ChromaClient() diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py deleted file mode 100644 index b32c83f88..000000000 --- a/backend/open_webui/apps/webui/main.py +++ /dev/null @@ -1,458 +0,0 @@ -import inspect -import json -import logging -import time -from typing import AsyncGenerator, Generator, Iterator - -from open_webui.apps.socket.main import get_event_call, get_event_emitter -from open_webui.apps.webui.models.functions import Functions -from open_webui.apps.webui.models.models import Models -from open_webui.apps.webui.routers import ( - auths, - chats, - folders, - configs, - files, - functions, - memories, - models, - knowledge, - prompts, - evaluations, - tools, - users, - utils, -) -from open_webui.apps.webui.utils import load_function_module_by_id -from open_webui.config import ( - ADMIN_EMAIL, - CORS_ALLOW_ORIGIN, - DEFAULT_MODELS, - DEFAULT_PROMPT_SUGGESTIONS, - DEFAULT_USER_ROLE, - ENABLE_COMMUNITY_SHARING, - ENABLE_LOGIN_FORM, - ENABLE_MESSAGE_RATING, - ENABLE_SIGNUP, - ENABLE_EVALUATION_ARENA_MODELS, - EVALUATION_ARENA_MODELS, - DEFAULT_ARENA_MODEL, - JWT_EXPIRES_IN, - ENABLE_OAUTH_ROLE_MANAGEMENT, - OAUTH_ROLES_CLAIM, - OAUTH_EMAIL_CLAIM, - OAUTH_PICTURE_CLAIM, - OAUTH_USERNAME_CLAIM, - OAUTH_ALLOWED_ROLES, - 
OAUTH_ADMIN_ROLES, - SHOW_ADMIN_DETAILS, - USER_PERMISSIONS, - WEBHOOK_URL, - WEBUI_AUTH, - WEBUI_BANNERS, - AppConfig, -) -from open_webui.env import ( - ENV, - WEBUI_AUTH_TRUSTED_EMAIL_HEADER, - WEBUI_AUTH_TRUSTED_NAME_HEADER, -) -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse -from pydantic import BaseModel -from open_webui.utils.misc import ( - openai_chat_chunk_message_template, - openai_chat_completion_message_template, -) -from open_webui.utils.payload import ( - apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, -) - - -from open_webui.utils.tools import get_tools - -app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None) - -log = logging.getLogger(__name__) - -app.state.config = AppConfig() - -app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP -app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM -app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN -app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER -app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER - - -app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS -app.state.config.ADMIN_EMAIL = ADMIN_EMAIL - - -app.state.config.DEFAULT_MODELS = DEFAULT_MODELS -app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS -app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE -app.state.config.USER_PERMISSIONS = USER_PERMISSIONS -app.state.config.WEBHOOK_URL = WEBHOOK_URL -app.state.config.BANNERS = WEBUI_BANNERS - -app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING -app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING - -app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS -app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS - -app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM -app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM -app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM - -app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT -app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM -app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES -app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES - -app.state.MODELS = {} -app.state.TOOLS = {} -app.state.FUNCTIONS = {} - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -app.include_router(configs.router, prefix="/configs", tags=["configs"]) - -app.include_router(auths.router, prefix="/auths", tags=["auths"]) -app.include_router(users.router, prefix="/users", tags=["users"]) - -app.include_router(chats.router, prefix="/chats", tags=["chats"]) - -app.include_router(models.router, prefix="/models", tags=["models"]) -app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"]) -app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) -app.include_router(tools.router, prefix="/tools", tags=["tools"]) -app.include_router(functions.router, prefix="/functions", tags=["functions"]) - -app.include_router(memories.router, prefix="/memories", tags=["memories"]) -app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"]) - -app.include_router(folders.router, prefix="/folders", tags=["folders"]) -app.include_router(files.router, prefix="/files", tags=["files"]) - -app.include_router(utils.router, 
prefix="/utils", tags=["utils"]) - - -@app.get("/") -async def get_status(): - return { - "status": True, - "auth": WEBUI_AUTH, - "default_models": app.state.config.DEFAULT_MODELS, - "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, - } - - -async def get_all_models(): - models = [] - pipe_models = await get_pipe_models() - models = models + pipe_models - - if app.state.config.ENABLE_EVALUATION_ARENA_MODELS: - arena_models = [] - if len(app.state.config.EVALUATION_ARENA_MODELS) > 0: - arena_models = [ - { - "id": model["id"], - "name": model["name"], - "info": { - "meta": model["meta"], - }, - "object": "model", - "created": int(time.time()), - "owned_by": "arena", - "arena": True, - } - for model in app.state.config.EVALUATION_ARENA_MODELS - ] - else: - # Add default arena model - arena_models = [ - { - "id": DEFAULT_ARENA_MODEL["id"], - "name": DEFAULT_ARENA_MODEL["name"], - "info": { - "meta": DEFAULT_ARENA_MODEL["meta"], - }, - "object": "model", - "created": int(time.time()), - "owned_by": "arena", - "arena": True, - } - ] - models = models + arena_models - return models - - -def get_function_module(pipe_id: str): - # Check if function is already loaded - if pipe_id not in app.state.FUNCTIONS: - function_module, _, _ = load_function_module_by_id(pipe_id) - app.state.FUNCTIONS[pipe_id] = function_module - else: - function_module = app.state.FUNCTIONS[pipe_id] - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(pipe_id) - function_module.valves = function_module.Valves(**(valves if valves else {})) - return function_module - - -async def get_pipe_models(): - pipes = Functions.get_functions_by_type("pipe", active_only=True) - pipe_models = [] - - for pipe in pipes: - function_module = get_function_module(pipe.id) - - # Check if function is a manifold - if hasattr(function_module, "pipes"): - sub_pipes = [] - - # Check if pipes is a function or a list - - try: - if callable(function_module.pipes): - sub_pipes = function_module.pipes() - else: - sub_pipes = function_module.pipes - except Exception as e: - log.exception(e) - sub_pipes = [] - - print(sub_pipes) - - for p in sub_pipes: - sub_pipe_id = f'{pipe.id}.{p["id"]}' - sub_pipe_name = p["name"] - - if hasattr(function_module, "name"): - sub_pipe_name = f"{function_module.name}{sub_pipe_name}" - - pipe_flag = {"type": pipe.type} - pipe_models.append( - { - "id": sub_pipe_id, - "name": sub_pipe_name, - "object": "model", - "created": pipe.created_at, - "owned_by": "openai", - "pipe": pipe_flag, - } - ) - else: - pipe_flag = {"type": "pipe"} - - pipe_models.append( - { - "id": pipe.id, - "name": pipe.name, - "object": "model", - "created": pipe.created_at, - "owned_by": "openai", - "pipe": pipe_flag, - } - ) - - return pipe_models - - -async def execute_pipe(pipe, params): - if inspect.iscoroutinefunction(pipe): - return await pipe(**params) - else: - return pipe(**params) - - -async def get_message_content(res: str | Generator | AsyncGenerator) -> str: - if isinstance(res, str): - return res - if isinstance(res, Generator): - return "".join(map(str, res)) - if isinstance(res, AsyncGenerator): - return "".join([str(stream) async for stream in res]) - - -def process_line(form_data: dict, line): - if isinstance(line, BaseModel): - line = line.model_dump_json() - line = f"data: {line}" - if isinstance(line, dict): - line = f"data: {json.dumps(line)}" - - try: - line = line.decode("utf-8") - except Exception: - pass - - if 
line.startswith("data:"): - return f"{line}\n\n" - else: - line = openai_chat_chunk_message_template(form_data["model"], line) - return f"data: {json.dumps(line)}\n\n" - - -def get_pipe_id(form_data: dict) -> str: - pipe_id = form_data["model"] - if "." in pipe_id: - pipe_id, _ = pipe_id.split(".", 1) - print(pipe_id) - return pipe_id - - -def get_function_params(function_module, form_data, user, extra_params=None): - if extra_params is None: - extra_params = {} - - pipe_id = get_pipe_id(form_data) - - # Get the signature of the function - sig = inspect.signature(function_module.pipe) - params = {"body": form_data} | { - k: v for k, v in extra_params.items() if k in sig.parameters - } - - if "__user__" in params and hasattr(function_module, "UserValves"): - user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) - try: - params["__user__"]["valves"] = function_module.UserValves(**user_valves) - except Exception as e: - log.exception(e) - params["__user__"]["valves"] = function_module.UserValves() - - return params - - -async def generate_function_chat_completion(form_data, user): - model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id) - - metadata = form_data.pop("metadata", {}) - - files = metadata.get("files", []) - tool_ids = metadata.get("tool_ids", []) - # Check if tool_ids is None - if tool_ids is None: - tool_ids = [] - - __event_emitter__ = None - __event_call__ = None - __task__ = None - __task_body__ = None - - if metadata: - if all(k in metadata for k in ("session_id", "chat_id", "message_id")): - __event_emitter__ = get_event_emitter(metadata) - __event_call__ = get_event_call(metadata) - __task__ = metadata.get("task", None) - __task_body__ = metadata.get("task_body", None) - - extra_params = { - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__task__": __task__, - "__task_body__": __task_body__, - "__files__": files, - "__user__": { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - } - extra_params["__tools__"] = get_tools( - app, - tool_ids, - user, - { - **extra_params, - "__model__": app.state.MODELS[form_data["model"]], - "__messages__": form_data["messages"], - "__files__": files, - }, - ) - - if model_info: - if model_info.base_model_id: - form_data["model"] = model_info.base_model_id - - params = model_info.params.model_dump() - form_data = apply_model_params_to_body_openai(params, form_data) - form_data = apply_model_system_prompt_to_body(params, form_data, user) - - pipe_id = get_pipe_id(form_data) - function_module = get_function_module(pipe_id) - - pipe = function_module.pipe - params = get_function_params(function_module, form_data, user, extra_params) - - if form_data.get("stream", False): - - async def stream_content(): - try: - res = await execute_pipe(pipe, params) - - # Directly return if the response is a StreamingResponse - if isinstance(res, StreamingResponse): - async for data in res.body_iterator: - yield data - return - if isinstance(res, dict): - yield f"data: {json.dumps(res)}\n\n" - return - - except Exception as e: - print(f"Error: {e}") - yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" - return - - if isinstance(res, str): - message = openai_chat_chunk_message_template(form_data["model"], res) - yield f"data: {json.dumps(message)}\n\n" - - if isinstance(res, Iterator): - for line in res: - yield process_line(form_data, line) - - if isinstance(res, AsyncGenerator): - async for line in res: - yield 
process_line(form_data, line) - - if isinstance(res, str) or isinstance(res, Generator): - finish_message = openai_chat_chunk_message_template( - form_data["model"], "" - ) - finish_message["choices"][0]["finish_reason"] = "stop" - yield f"data: {json.dumps(finish_message)}\n\n" - yield "data: [DONE]" - - return StreamingResponse(stream_content(), media_type="text/event-stream") - else: - try: - res = await execute_pipe(pipe, params) - - except Exception as e: - print(f"Error: {e}") - return {"error": {"detail": str(e)}} - - if isinstance(res, StreamingResponse) or isinstance(res, dict): - return res - if isinstance(res, BaseModel): - return res.model_dump() - - message = await get_message_content(res) - return openai_chat_completion_message_template(form_data["model"], message) diff --git a/backend/open_webui/apps/webui/models/documents.py b/backend/open_webui/apps/webui/models/documents.py deleted file mode 100644 index 0b96c2574..000000000 --- a/backend/open_webui/apps/webui/models/documents.py +++ /dev/null @@ -1,157 +0,0 @@ -import json -import logging -import time -from typing import Optional - -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.env import SRC_LOG_LEVELS -from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text - -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["MODELS"]) - -#################### -# Documents DB Schema -#################### - - -class Document(Base): - __tablename__ = "document" - - collection_name = Column(String, primary_key=True) - name = Column(String, unique=True) - title = Column(Text) - filename = Column(Text) - content = Column(Text, nullable=True) - user_id = Column(String) - timestamp = Column(BigInteger) - - -class DocumentModel(BaseModel): - model_config = ConfigDict(from_attributes=True) - - collection_name: str - name: str - title: str - filename: str - content: Optional[str] = None - user_id: str - timestamp: int # timestamp in epoch - - -#################### -# Forms -#################### - - -class DocumentResponse(BaseModel): - collection_name: str - name: str - title: str - filename: str - content: Optional[dict] = None - user_id: str - timestamp: int # timestamp in epoch - - -class DocumentUpdateForm(BaseModel): - name: str - title: str - - -class DocumentForm(DocumentUpdateForm): - collection_name: str - filename: str - content: Optional[str] = None - - -class DocumentsTable: - def insert_new_doc( - self, user_id: str, form_data: DocumentForm - ) -> Optional[DocumentModel]: - with get_db() as db: - document = DocumentModel( - **{ - **form_data.model_dump(), - "user_id": user_id, - "timestamp": int(time.time()), - } - ) - - try: - result = Document(**document.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return DocumentModel.model_validate(result) - else: - return None - except Exception: - return None - - def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: - try: - with get_db() as db: - document = db.query(Document).filter_by(name=name).first() - return DocumentModel.model_validate(document) if document else None - except Exception: - return None - - def get_docs(self) -> list[DocumentModel]: - with get_db() as db: - return [ - DocumentModel.model_validate(doc) for doc in db.query(Document).all() - ] - - def update_doc_by_name( - self, name: str, form_data: DocumentUpdateForm - ) -> Optional[DocumentModel]: - try: - with get_db() as db: - db.query(Document).filter_by(name=name).update( - { - "title": 
form_data.title, - "name": form_data.name, - "timestamp": int(time.time()), - } - ) - db.commit() - return self.get_doc_by_name(form_data.name) - except Exception as e: - log.exception(e) - return None - - def update_doc_content_by_name( - self, name: str, updated: dict - ) -> Optional[DocumentModel]: - try: - doc = self.get_doc_by_name(name) - doc_content = json.loads(doc.content if doc.content else "{}") - doc_content = {**doc_content, **updated} - - with get_db() as db: - db.query(Document).filter_by(name=name).update( - { - "content": json.dumps(doc_content), - "timestamp": int(time.time()), - } - ) - db.commit() - return self.get_doc_by_name(name) - except Exception as e: - log.exception(e) - return None - - def delete_doc_by_name(self, name: str) -> bool: - try: - with get_db() as db: - db.query(Document).filter_by(name=name).delete() - db.commit() - return True - except Exception: - return False - - -Documents = DocumentsTable() diff --git a/backend/open_webui/apps/webui/routers/documents.py b/backend/open_webui/apps/webui/routers/documents.py deleted file mode 100644 index c8f27852f..000000000 --- a/backend/open_webui/apps/webui/routers/documents.py +++ /dev/null @@ -1,155 +0,0 @@ -import json -from typing import Optional - -from open_webui.apps.webui.models.documents import ( - DocumentForm, - DocumentResponse, - Documents, - DocumentUpdateForm, -) -from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, status -from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_verified_user - -router = APIRouter() - -############################ -# GetDocuments -############################ - - -@router.get("/", response_model=list[DocumentResponse]) -async def get_documents(user=Depends(get_verified_user)): - docs = [ - DocumentResponse( - **{ - **doc.model_dump(), - "content": json.loads(doc.content if doc.content else "{}"), - } - ) - for doc in Documents.get_docs() - ] - return docs - - -############################ -# CreateNewDoc -############################ - - -@router.post("/create", response_model=Optional[DocumentResponse]) -async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)): - doc = Documents.get_doc_by_name(form_data.name) - if doc is None: - doc = Documents.insert_new_doc(user.id, form_data) - - if doc: - return DocumentResponse( - **{ - **doc.model_dump(), - "content": json.loads(doc.content if doc.content else "{}"), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.FILE_EXISTS, - ) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.NAME_TAG_TAKEN, - ) - - -############################ -# GetDocByName -############################ - - -@router.get("/doc", response_model=Optional[DocumentResponse]) -async def get_doc_by_name(name: str, user=Depends(get_verified_user)): - doc = Documents.get_doc_by_name(name) - - if doc: - return DocumentResponse( - **{ - **doc.model_dump(), - "content": json.loads(doc.content if doc.content else "{}"), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - - -############################ -# TagDocByName -############################ - - -class TagItem(BaseModel): - name: str - - -class TagDocumentForm(BaseModel): - name: str - tags: list[dict] - - -@router.post("/doc/tags", response_model=Optional[DocumentResponse]) -async def tag_doc_by_name(form_data: 
TagDocumentForm, user=Depends(get_verified_user)): - doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags}) - - if doc: - return DocumentResponse( - **{ - **doc.model_dump(), - "content": json.loads(doc.content if doc.content else "{}"), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - - -############################ -# UpdateDocByName -############################ - - -@router.post("/doc/update", response_model=Optional[DocumentResponse]) -async def update_doc_by_name( - name: str, - form_data: DocumentUpdateForm, - user=Depends(get_admin_user), -): - doc = Documents.update_doc_by_name(name, form_data) - if doc: - return DocumentResponse( - **{ - **doc.model_dump(), - "content": json.loads(doc.content if doc.content else "{}"), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.NAME_TAG_TAKEN, - ) - - -############################ -# DeleteDocByName -############################ - - -@router.delete("/doc/delete", response_model=bool) -async def delete_doc_by_name(name: str, user=Depends(get_admin_user)): - result = Documents.delete_doc_by_name(name) - return result diff --git a/backend/open_webui/apps/webui/routers/models.py b/backend/open_webui/apps/webui/routers/models.py deleted file mode 100644 index a5cb2395e..000000000 --- a/backend/open_webui/apps/webui/routers/models.py +++ /dev/null @@ -1,104 +0,0 @@ -from typing import Optional - -from open_webui.apps.webui.models.models import ( - ModelForm, - ModelModel, - ModelResponse, - Models, -) -from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, Request, status -from open_webui.utils.utils import get_admin_user, get_verified_user - -router = APIRouter() - -########################### -# getModels -########################### - - -@router.get("/", response_model=list[ModelResponse]) -async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): - if id: - model = Models.get_model_by_id(id) - if model: - return [model] - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - else: - return Models.get_all_models() - - -############################ -# AddNewModel -############################ - - -@router.post("/add", response_model=Optional[ModelModel]) -async def add_new_model( - request: Request, - form_data: ModelForm, - user=Depends(get_admin_user), -): - if form_data.id in request.app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.MODEL_ID_TAKEN, - ) - else: - model = Models.insert_new_model(form_data, user.id) - - if model: - return model - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), - ) - - -############################ -# UpdateModelById -############################ - - -@router.post("/update", response_model=Optional[ModelModel]) -async def update_model_by_id( - request: Request, - id: str, - form_data: ModelForm, - user=Depends(get_admin_user), -): - model = Models.get_model_by_id(id) - if model: - model = Models.update_model_by_id(id, form_data) - return model - else: - if form_data.id in request.app.state.MODELS: - model = Models.insert_new_model(form_data, user.id) - if model: - return model - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), - ) - else: - raise 
HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), - ) - - -############################ -# DeleteModelById -############################ - - -@router.delete("/delete", response_model=bool) -async def delete_model_by_id(id: str, user=Depends(get_admin_user)): - result = Models.delete_model_by_id(id) - return result diff --git a/backend/open_webui/apps/webui/routers/prompts.py b/backend/open_webui/apps/webui/routers/prompts.py deleted file mode 100644 index 593c643b9..000000000 --- a/backend/open_webui/apps/webui/routers/prompts.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Optional - -from open_webui.apps.webui.models.prompts import PromptForm, PromptModel, Prompts -from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, status -from open_webui.utils.utils import get_admin_user, get_verified_user - -router = APIRouter() - -############################ -# GetPrompts -############################ - - -@router.get("/", response_model=list[PromptModel]) -async def get_prompts(user=Depends(get_verified_user)): - return Prompts.get_prompts() - - -############################ -# CreateNewPrompt -############################ - - -@router.post("/create", response_model=Optional[PromptModel]) -async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)): - prompt = Prompts.get_prompt_by_command(form_data.command) - if prompt is None: - prompt = Prompts.insert_new_prompt(user.id, form_data) - - if prompt: - return prompt - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(), - ) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.COMMAND_TAKEN, - ) - - -############################ -# GetPromptByCommand -############################ - - -@router.get("/command/{command}", response_model=Optional[PromptModel]) -async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): - prompt = Prompts.get_prompt_by_command(f"/{command}") - - if prompt: - return prompt - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - - -############################ -# UpdatePromptByCommand -############################ - - -@router.post("/command/{command}/update", response_model=Optional[PromptModel]) -async def update_prompt_by_command( - command: str, - form_data: PromptForm, - user=Depends(get_admin_user), -): - prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) - if prompt: - return prompt - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - - -############################ -# DeletePromptByCommand -############################ - - -@router.delete("/command/{command}/delete", response_model=bool) -async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)): - result = Prompts.delete_prompt_by_command(f"/{command}") - return result diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index c3ab184fb..fbbc80628 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -10,7 +10,7 @@ from urllib.parse import urlparse import chromadb import requests import yaml -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from open_webui.env import ( OPEN_WEBUI_DIR, DATA_DIR, @@ -20,6 +20,8 @@ from open_webui.env import ( WEBUI_FAVICON_URL, WEBUI_NAME, 
log, + DATABASE_URL, + OFFLINE_MODE ) from pydantic import BaseModel from sqlalchemy import JSON, Column, DateTime, Integer, func @@ -264,6 +266,13 @@ class AppConfig: # WEBUI_AUTH (Required for security) #################################### +ENABLE_API_KEY = PersistentConfig( + "ENABLE_API_KEY", + "auth.api_key.enable", + os.environ.get("ENABLE_API_KEY", "True").lower() == "true", +) + + JWT_EXPIRES_IN = PersistentConfig( "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") ) @@ -433,6 +442,15 @@ OAUTH_ADMIN_ROLES = PersistentConfig( [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")], ) +OAUTH_ALLOWED_DOMAINS = PersistentConfig( + "OAUTH_ALLOWED_DOMAINS", + "oauth.allowed_domains", + [ + domain.strip() + for domain in os.environ.get("OAUTH_ALLOWED_DOMAINS", "*").split(",") + ], +) + def load_oauth_providers(): OAUTH_PROVIDERS.clear() @@ -587,6 +605,12 @@ OLLAMA_API_BASE_URL = os.environ.get( ) OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "") +if OLLAMA_BASE_URL: + # Remove trailing slash + OLLAMA_BASE_URL = ( + OLLAMA_BASE_URL[:-1] if OLLAMA_BASE_URL.endswith("/") else OLLAMA_BASE_URL + ) + K8S_FLAG = os.environ.get("K8S_FLAG", "") USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false") @@ -618,6 +642,12 @@ OLLAMA_BASE_URLS = PersistentConfig( "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS ) +OLLAMA_API_CONFIGS = PersistentConfig( + "OLLAMA_API_CONFIGS", + "ollama.api_configs", + {}, +) + #################################### # OPENAI_API #################################### @@ -658,15 +688,20 @@ OPENAI_API_BASE_URLS = PersistentConfig( "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS ) -OPENAI_API_KEY = "" +OPENAI_API_CONFIGS = PersistentConfig( + "OPENAI_API_CONFIGS", + "openai.api_configs", + {}, +) +# Get the actual OpenAI API key based on the base URL +OPENAI_API_KEY = "" try: OPENAI_API_KEY = OPENAI_API_KEYS.value[ OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1") ] except Exception: pass - OPENAI_API_BASE_URL = "https://api.openai.com/v1" #################################### @@ -689,6 +724,7 @@ ENABLE_LOGIN_FORM = PersistentConfig( os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true", ) + DEFAULT_LOCALE = PersistentConfig( "DEFAULT_LOCALE", "ui.default_locale", @@ -733,18 +769,47 @@ DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig( ], ) +MODEL_ORDER_LIST = PersistentConfig( + "MODEL_ORDER_LIST", + "ui.model_order_list", + [], +) + DEFAULT_USER_ROLE = PersistentConfig( "DEFAULT_USER_ROLE", "ui.default_user_role", os.getenv("DEFAULT_USER_ROLE", "pending"), ) -USER_PERMISSIONS_CHAT_DELETION = ( - os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true" +USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower() + == "true" ) -USER_PERMISSIONS_CHAT_EDITING = ( - os.environ.get("USER_PERMISSIONS_CHAT_EDITING", "True").lower() == "true" +USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS", "False").lower() + == "true" +) + +USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS", "False").lower() + == "true" +) + +USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true" +) + +USER_PERMISSIONS_CHAT_FILE_UPLOAD = ( + os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true" +) + 
+USER_PERMISSIONS_CHAT_DELETE = ( + os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true" +) + +USER_PERMISSIONS_CHAT_EDIT = ( + os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true" ) USER_PERMISSIONS_CHAT_TEMPORARY = ( @@ -753,13 +818,20 @@ USER_PERMISSIONS_CHAT_TEMPORARY = ( USER_PERMISSIONS = PersistentConfig( "USER_PERMISSIONS", - "ui.user_permissions", + "user.permissions", { + "workspace": { + "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS, + "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS, + "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS, + "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS, + }, "chat": { - "deletion": USER_PERMISSIONS_CHAT_DELETION, - "editing": USER_PERMISSIONS_CHAT_EDITING, + "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD, + "delete": USER_PERMISSIONS_CHAT_DELETE, + "edit": USER_PERMISSIONS_CHAT_EDIT, "temporary": USER_PERMISSIONS_CHAT_TEMPORARY, - } + }, }, ) @@ -785,18 +857,6 @@ DEFAULT_ARENA_MODEL = { }, } -ENABLE_MODEL_FILTER = PersistentConfig( - "ENABLE_MODEL_FILTER", - "model_filter.enable", - os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true", -) -MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") -MODEL_FILTER_LIST = PersistentConfig( - "MODEL_FILTER_LIST", - "model_filter.list", - [model.strip() for model in MODEL_FILTER_LIST.split(";")], -) - WEBHOOK_URL = PersistentConfig( "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "") ) @@ -910,25 +970,155 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""), ) +DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. + +Examples of titles: +📉 Stock Market Trends +🍪 Perfect Chocolate Chip Recipe +Evolution of Music Streaming +Remote Work Productivity Tips +Artificial Intelligence in Healthcare +🎮 Video Game Development Insights + + +{{MESSAGES:END:2}} +""" + + TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig( "TAGS_GENERATION_PROMPT_TEMPLATE", "task.tags.prompt_template", os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""), ) -ENABLE_SEARCH_QUERY = PersistentConfig( - "ENABLE_SEARCH_QUERY", - "task.search.enable", - os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true", +DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE = """### Task: +Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags. + +### Guidelines: +- Start with high-level domains (e.g. 
Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education) +- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation +- If content is too short (less than 3 messages) or too diverse, use only ["General"] +- Use the chat's primary language; default to English if multilingual +- Prioritize accuracy over specificity + +### Output: +JSON format: { "tags": ["tag1", "tag2", "tag3"] } + +### Chat History: + +{{MESSAGES:END:6}} +""" + +ENABLE_TAGS_GENERATION = PersistentConfig( + "ENABLE_TAGS_GENERATION", + "task.tags.enable", + os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true", ) -SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig( - "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", - "task.search.prompt_template", - os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""), +ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig( + "ENABLE_SEARCH_QUERY_GENERATION", + "task.query.search.enable", + os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true", ) +ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig( + "ENABLE_RETRIEVAL_QUERY_GENERATION", + "task.query.retrieval.enable", + os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true", +) + + +QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "QUERY_GENERATION_PROMPT_TEMPLATE", + "task.query.prompt_template", + os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""), +) + +DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task: +Analyze the chat history to determine the necessity of generating search queries, in the given language. By default, **prioritize generating 1-3 broad and relevant search queries** unless it is absolutely certain that no additional information is required. The aim is to retrieve comprehensive, updated, and valuable information even with minimal uncertainty. If no search is unequivocally needed, return an empty list. + +### Guidelines: +- Respond **EXCLUSIVELY** with a JSON object. Any form of extra commentary, explanation, or additional text is strictly prohibited. +- When generating search queries, respond in the format: { "queries": ["query1", "query2"] }, ensuring each query is distinct, concise, and relevant to the topic. +- If and only if it is entirely certain that no useful results can be retrieved by a search, return: { "queries": [] }. +- Err on the side of suggesting search queries if there is **any chance** they might provide useful or updated information. +- Be concise and focused on composing high-quality search queries, avoiding unnecessary elaboration, commentary, or assumptions. +- Today's date is: {{CURRENT_DATE}}. +- Always prioritize providing actionable and broad queries that maximize informational coverage. 
+ +### Output: +Strictly return in JSON format: +{ + "queries": ["query1", "query2"] +} + +### Chat History: + +{{MESSAGES:END:6}} + +""" + +ENABLE_AUTOCOMPLETE_GENERATION = PersistentConfig( + "ENABLE_AUTOCOMPLETE_GENERATION", + "task.autocomplete.enable", + os.environ.get("ENABLE_AUTOCOMPLETE_GENERATION", "True").lower() == "true", +) + +AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = PersistentConfig( + "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH", + "task.autocomplete.input_max_length", + int(os.environ.get("AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH", "-1")), +) + +AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", + "task.autocomplete.prompt_template", + os.environ.get("AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", ""), +) + + +DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = """### Task: +You are an autocompletion system. Continue the text in `` based on the **completion type** in `` and the given language. + +### **Instructions**: +1. Analyze `` for context and meaning. +2. Use `` to guide your output: + - **General**: Provide a natural, concise continuation. + - **Search Query**: Complete as if generating a realistic search query. +3. Start as if you are directly continuing ``. Do **not** repeat, paraphrase, or respond as a model. Simply complete the text. +4. Ensure the continuation: + - Flows naturally from ``. + - Avoids repetition, overexplaining, or unrelated ideas. +5. If unsure, return: `{ "text": "" }`. + +### **Output Rules**: +- Respond only in JSON format: `{ "text": "" }`. + +### **Examples**: +#### Example 1: +Input: +General +The sun was setting over the horizon, painting the sky +Output: +{ "text": "with vibrant shades of orange and pink." } + +#### Example 2: +Input: +Search Query +Top-rated restaurants in +Output: +{ "text": "New York City for Italian cuisine." } + +--- +### Context: + +{{MESSAGES:END:6}} + +{{TYPE}} +{{PROMPT}} +#### Output: +""" TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", @@ -937,6 +1127,19 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( ) +DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text.""" + + +DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). + +Message: ```{{prompt}}```""" + +DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" + +Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. 
+ +Responses from models: {{responses}}""" + #################################### # Vector Database #################################### @@ -949,6 +1152,8 @@ CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT) CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE) CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "") CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000")) +CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "") +CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get("CHROMA_CLIENT_AUTH_CREDENTIALS", "") # Comma-separated list of header=value pairs CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "") if CHROMA_HTTP_HEADERS: @@ -966,6 +1171,21 @@ MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db") # Qdrant QDRANT_URI = os.environ.get("QDRANT_URI", None) +QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None) + +# OpenSearch +OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200") +OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", True) +OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False) +OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None) +OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None) + +# Pgvector +PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL) +if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"): + raise ValueError( + "Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database." + ) #################################### # Information Retrieval (RAG) @@ -1045,11 +1265,11 @@ RAG_EMBEDDING_MODEL = PersistentConfig( log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}") RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( - os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true" + not OFFLINE_MODE and os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true" ) RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( - os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" + os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true" ) RAG_EMBEDDING_BATCH_SIZE = PersistentConfig( @@ -1070,11 +1290,11 @@ if RAG_RERANKING_MODEL.value != "": log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}") RAG_RERANKING_MODEL_AUTO_UPDATE = ( - os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true" + not OFFLINE_MODE and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true" ) RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( - os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" + os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true" ) @@ -1102,21 +1322,32 @@ CHUNK_OVERLAP = PersistentConfig( int(os.environ.get("CHUNK_OVERLAP", "100")), ) -DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules. +DEFAULT_RAG_TEMPLATE = """### Task: +Respond to the user query using the provided context, incorporating inline citations in the format [source_id] **only when the tag is explicitly provided** in the context. + +### Guidelines: +- If you don't know the answer, clearly state that. +- If uncertain, ask the user for clarification. +- Respond in the same language as the user's query. 
+- If the context is unreadable or of poor quality, inform the user and provide the best possible answer. +- If the answer isn't present in the context but you possess the knowledge, explain this to the user and provide the answer using your own understanding. +- **Only include inline citations using [source_id] when a tag is explicitly provided in the context.** +- Do not cite if the tag is not provided in the context. +- Do not use XML tags in your response. +- Ensure citations are concise and directly related to the information provided. + +### Example of Citation: +If the user asks about a specific topic and the information is found in "whitepaper.pdf" with a provided , the response should include the citation like so: +* "According to the study, the proposed method increases efficiency by 20% [whitepaper.pdf]." +If no is present, the response should omit the citation. + +### Output: +Provide a clear and direct response to the user's query, including inline citations in the format [source_id] only when the tag is present in the context. {{CONTEXT}} - -- If you don't know, just say so. -- If you are not sure, ask for clarification. -- Answer in the same language as the user query. -- If the context appears unreadable or of poor quality, tell the user then answer as best as you can. -- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge. -- Answer directly and without using xml tags. - - {{QUERY}} @@ -1139,6 +1370,19 @@ RAG_OPENAI_API_KEY = PersistentConfig( os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY), ) +RAG_OLLAMA_BASE_URL = PersistentConfig( + "RAG_OLLAMA_BASE_URL", + "rag.ollama.url", + os.getenv("RAG_OLLAMA_BASE_URL", OLLAMA_BASE_URL), +) + +RAG_OLLAMA_API_KEY = PersistentConfig( + "RAG_OLLAMA_API_KEY", + "rag.ollama.key", + os.getenv("RAG_OLLAMA_API_KEY", ""), +) + + ENABLE_RAG_LOCAL_WEB_FETCH = ( os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true" ) @@ -1149,6 +1393,12 @@ YOUTUBE_LOADER_LANGUAGE = PersistentConfig( os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","), ) +YOUTUBE_LOADER_PROXY_URL = PersistentConfig( + "YOUTUBE_LOADER_PROXY_URL", + "rag.youtube_loader_proxy_url", + os.getenv("YOUTUBE_LOADER_PROXY_URL", ""), +) + ENABLE_RAG_WEB_SEARCH = PersistentConfig( "ENABLE_RAG_WEB_SEARCH", @@ -1198,6 +1448,18 @@ BRAVE_SEARCH_API_KEY = PersistentConfig( os.getenv("BRAVE_SEARCH_API_KEY", ""), ) +KAGI_SEARCH_API_KEY = PersistentConfig( + "KAGI_SEARCH_API_KEY", + "rag.web.search.kagi_search_api_key", + os.getenv("KAGI_SEARCH_API_KEY", ""), +) + +MOJEEK_SEARCH_API_KEY = PersistentConfig( + "MOJEEK_SEARCH_API_KEY", + "rag.web.search.mojeek_search_api_key", + os.getenv("MOJEEK_SEARCH_API_KEY", ""), +) + SERPSTACK_API_KEY = PersistentConfig( "SERPSTACK_API_KEY", "rag.web.search.serpstack_api_key", @@ -1228,6 +1490,12 @@ TAVILY_API_KEY = PersistentConfig( os.getenv("TAVILY_API_KEY", ""), ) +JINA_API_KEY = PersistentConfig( + "JINA_API_KEY", + "rag.web.search.jina_api_key", + os.getenv("JINA_API_KEY", ""), +) + SEARCHAPI_API_KEY = PersistentConfig( "SEARCHAPI_API_KEY", "rag.web.search.searchapi_api_key", @@ -1240,6 +1508,21 @@ SEARCHAPI_ENGINE = PersistentConfig( os.getenv("SEARCHAPI_ENGINE", ""), ) +BING_SEARCH_V7_ENDPOINT = PersistentConfig( + "BING_SEARCH_V7_ENDPOINT", + "rag.web.search.bing_search_v7_endpoint", + os.environ.get( + "BING_SEARCH_V7_ENDPOINT", "https://api.bing.microsoft.com/v7.0/search" + ), +) + +BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig( + 
"BING_SEARCH_V7_SUBSCRIPTION_KEY", + "rag.web.search.bing_search_v7_subscription_key", + os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""), +) + + RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig( "RAG_WEB_SEARCH_RESULT_COUNT", "rag.web.search.result_count", @@ -1291,7 +1574,7 @@ AUTOMATIC1111_CFG_SCALE = PersistentConfig( AUTOMATIC1111_SAMPLER = PersistentConfig( - "AUTOMATIC1111_SAMPLERE", + "AUTOMATIC1111_SAMPLER", "image_generation.automatic1111.sampler", ( os.environ.get("AUTOMATIC1111_SAMPLER") @@ -1477,7 +1760,7 @@ WHISPER_MODEL = PersistentConfig( WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models") WHISPER_MODEL_AUTO_UPDATE = ( - os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" + not OFFLINE_MODE and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" ) @@ -1560,3 +1843,74 @@ AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT = PersistentConfig( "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3" ), ) + + +#################################### +# LDAP +#################################### + +ENABLE_LDAP = PersistentConfig( + "ENABLE_LDAP", + "ldap.enable", + os.environ.get("ENABLE_LDAP", "false").lower() == "true", +) + +LDAP_SERVER_LABEL = PersistentConfig( + "LDAP_SERVER_LABEL", + "ldap.server.label", + os.environ.get("LDAP_SERVER_LABEL", "LDAP Server"), +) + +LDAP_SERVER_HOST = PersistentConfig( + "LDAP_SERVER_HOST", + "ldap.server.host", + os.environ.get("LDAP_SERVER_HOST", "localhost"), +) + +LDAP_SERVER_PORT = PersistentConfig( + "LDAP_SERVER_PORT", + "ldap.server.port", + int(os.environ.get("LDAP_SERVER_PORT", "389")), +) + +LDAP_ATTRIBUTE_FOR_USERNAME = PersistentConfig( + "LDAP_ATTRIBUTE_FOR_USERNAME", + "ldap.server.attribute_for_username", + os.environ.get("LDAP_ATTRIBUTE_FOR_USERNAME", "uid"), +) + +LDAP_APP_DN = PersistentConfig( + "LDAP_APP_DN", "ldap.server.app_dn", os.environ.get("LDAP_APP_DN", "") +) + +LDAP_APP_PASSWORD = PersistentConfig( + "LDAP_APP_PASSWORD", + "ldap.server.app_password", + os.environ.get("LDAP_APP_PASSWORD", ""), +) + +LDAP_SEARCH_BASE = PersistentConfig( + "LDAP_SEARCH_BASE", "ldap.server.users_dn", os.environ.get("LDAP_SEARCH_BASE", "") +) + +LDAP_SEARCH_FILTERS = PersistentConfig( + "LDAP_SEARCH_FILTER", + "ldap.server.search_filter", + os.environ.get("LDAP_SEARCH_FILTER", ""), +) + +LDAP_USE_TLS = PersistentConfig( + "LDAP_USE_TLS", + "ldap.server.use_tls", + os.environ.get("LDAP_USE_TLS", "True").lower() == "true", +) + +LDAP_CA_CERT_FILE = PersistentConfig( + "LDAP_CA_CERT_FILE", + "ldap.server.ca_cert_file", + os.environ.get("LDAP_CA_CERT_FILE", ""), +) + +LDAP_CIPHERS = PersistentConfig( + "LDAP_CIPHERS", "ldap.server.ciphers", os.environ.get("LDAP_CIPHERS", "ALL") +) diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index d6f33af4a..c5fdfabfb 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -62,6 +62,7 @@ class ERROR_MESSAGES(str, Enum): NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." + API_KEY_NOT_ALLOWED = "Use of API key is not enabled in the environment." MALICIOUS = "Unusual activities detected, please try again in a few minutes." 
@@ -75,6 +76,7 @@ class ERROR_MESSAGES(str, Enum): OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found" OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance." + API_KEY_CREATION_NOT_ALLOWED = "API key creation is not allowed in the environment." EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." @@ -111,5 +113,6 @@ class TASKS(str, Enum): TAGS_GENERATION = "tags_generation" EMOJI_GENERATION = "emoji_generation" QUERY_GENERATION = "query_generation" + AUTOCOMPLETE_GENERATION = "autocomplete_generation" FUNCTION_CALLING = "function_calling" MOA_RESPONSE_GENERATION = "moa_response_generation" diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 4b61e1a89..0fd6080de 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -195,6 +195,15 @@ CHANGELOG = changelog_json SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true" +#################################### +# ENABLE_FORWARD_USER_INFO_HEADERS +#################################### + +ENABLE_FORWARD_USER_INFO_HEADERS = ( + os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true" +) + + #################################### # WEBUI_BUILD_HASH #################################### @@ -320,6 +329,9 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( ) WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) +BYPASS_MODEL_ACCESS_CONTROL = ( + os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true" +) #################################### # WEBUI_SECRET_KEY @@ -364,7 +376,7 @@ else: AIOHTTP_CLIENT_TIMEOUT = 300 AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get( - "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "3" + "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "" ) if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "": @@ -375,7 +387,7 @@ else: AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST ) except Exception: - AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 3 + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 5 #################################### # OFFLINE_MODE diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py new file mode 100644 index 000000000..16536a612 --- /dev/null +++ b/backend/open_webui/functions.py @@ -0,0 +1,316 @@ +import logging +import sys +import inspect +import json + +from pydantic import BaseModel +from typing import AsyncGenerator, Generator, Iterator +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + Request, + UploadFile, + status, +) +from starlette.responses import Response, StreamingResponse + + +from open_webui.socket.main import ( + get_event_call, + get_event_emitter, +) + + +from open_webui.models.functions import Functions +from open_webui.models.models import Models + +from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.tools import get_tools +from open_webui.utils.access_control import has_access + +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL + +from open_webui.utils.misc import ( + add_or_update_system_message, + get_last_user_message, + prepend_to_first_user_message_content, + openai_chat_chunk_message_template, + openai_chat_completion_message_template, +) +from open_webui.utils.payload import ( + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, +) + + 
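The new `backend/open_webui/functions.py` module, whose import block closes above, works against loaded "pipe" function objects purely by duck typing: `get_function_models` looks for an optional `pipes` attribute (a list, or a callable for manifolds) and an optional `name`, while `generate_function_chat_completion` inspects the signature of the object's `pipe` callable and injects only the `__user__` / `__event_emitter__` / `__tools__` style parameters it declares. A minimal sketch of the interface this implies — the model IDs and valve fields are invented for illustration and are not part of the codebase:

```python
# Hypothetical pipe function, shaped to satisfy the duck-typed interface that
# get_function_models() and generate_function_chat_completion() below rely on.
from pydantic import BaseModel


class Pipe:
    class Valves(BaseModel):
        # Admin-level settings; the loader rebuilds this from stored valves.
        api_base: str = "http://localhost:8080"  # illustrative field

    class UserValves(BaseModel):
        # Per-user settings; injected as params["__user__"]["valves"].
        temperature: float = 0.7                 # illustrative field

    def __init__(self):
        self.valves = self.Valves()

    def pipes(self):
        # Optional: presence of `pipes` makes this a "manifold" exposing sub-models.
        return [{"id": "demo-a", "name": "Demo A"}, {"id": "demo-b", "name": "Demo B"}]

    def pipe(self, body: dict, __user__: dict | None = None):
        # May be sync or async, and may return a str, dict, BaseModel,
        # Generator, or AsyncGenerator -- all of which the handlers below accept.
        return f"echo: {body['messages'][-1]['content']}"
```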
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +def get_function_module_by_id(request: Request, pipe_id: str): + # Check if function is already loaded + if pipe_id not in request.app.state.FUNCTIONS: + function_module, _, _ = load_function_module_by_id(pipe_id) + request.app.state.FUNCTIONS[pipe_id] = function_module + else: + function_module = request.app.state.FUNCTIONS[pipe_id] + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(pipe_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + return function_module + + +async def get_function_models(request): + pipes = Functions.get_functions_by_type("pipe", active_only=True) + pipe_models = [] + + for pipe in pipes: + function_module = get_function_module_by_id(request, pipe.id) + + # Check if function is a manifold + if hasattr(function_module, "pipes"): + sub_pipes = [] + + # Check if pipes is a function or a list + + try: + if callable(function_module.pipes): + sub_pipes = function_module.pipes() + else: + sub_pipes = function_module.pipes + except Exception as e: + log.exception(e) + sub_pipes = [] + + log.debug( + f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}" + ) + + for p in sub_pipes: + sub_pipe_id = f'{pipe.id}.{p["id"]}' + sub_pipe_name = p["name"] + + if hasattr(function_module, "name"): + sub_pipe_name = f"{function_module.name}{sub_pipe_name}" + + pipe_flag = {"type": pipe.type} + + pipe_models.append( + { + "id": sub_pipe_id, + "name": sub_pipe_name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": pipe_flag, + } + ) + else: + pipe_flag = {"type": "pipe"} + + log.debug( + f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}" + ) + + pipe_models.append( + { + "id": pipe.id, + "name": pipe.name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": pipe_flag, + } + ) + + return pipe_models + + +async def generate_function_chat_completion( + request, form_data, user, models: dict = {} +): + async def execute_pipe(pipe, params): + if inspect.iscoroutinefunction(pipe): + return await pipe(**params) + else: + return pipe(**params) + + async def get_message_content(res: str | Generator | AsyncGenerator) -> str: + if isinstance(res, str): + return res + if isinstance(res, Generator): + return "".join(map(str, res)) + if isinstance(res, AsyncGenerator): + return "".join([str(stream) async for stream in res]) + + def process_line(form_data: dict, line): + if isinstance(line, BaseModel): + line = line.model_dump_json() + line = f"data: {line}" + if isinstance(line, dict): + line = f"data: {json.dumps(line)}" + + try: + line = line.decode("utf-8") + except Exception: + pass + + if line.startswith("data:"): + return f"{line}\n\n" + else: + line = openai_chat_chunk_message_template(form_data["model"], line) + return f"data: {json.dumps(line)}\n\n" + + def get_pipe_id(form_data: dict) -> str: + pipe_id = form_data["model"] + if "." 
in pipe_id: + pipe_id, _ = pipe_id.split(".", 1) + return pipe_id + + def get_function_params(function_module, form_data, user, extra_params=None): + if extra_params is None: + extra_params = {} + + pipe_id = get_pipe_id(form_data) + + # Get the signature of the function + sig = inspect.signature(function_module.pipe) + params = {"body": form_data} | { + k: v for k, v in extra_params.items() if k in sig.parameters + } + + if "__user__" in params and hasattr(function_module, "UserValves"): + user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) + try: + params["__user__"]["valves"] = function_module.UserValves(**user_valves) + except Exception as e: + log.exception(e) + params["__user__"]["valves"] = function_module.UserValves() + + return params + + model_id = form_data.get("model") + model_info = Models.get_model_by_id(model_id) + + metadata = form_data.pop("metadata", {}) + + files = metadata.get("files", []) + tool_ids = metadata.get("tool_ids", []) + # Check if tool_ids is None + if tool_ids is None: + tool_ids = [] + + __event_emitter__ = None + __event_call__ = None + __task__ = None + __task_body__ = None + + if metadata: + if all(k in metadata for k in ("session_id", "chat_id", "message_id")): + __event_emitter__ = get_event_emitter(metadata) + __event_call__ = get_event_call(metadata) + __task__ = metadata.get("task", None) + __task_body__ = metadata.get("task_body", None) + + extra_params = { + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__task__": __task__, + "__task_body__": __task_body__, + "__files__": files, + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + "__metadata__": metadata, + "__request__": request, + } + extra_params["__tools__"] = get_tools( + request, + tool_ids, + user, + { + **extra_params, + "__model__": models.get(form_data["model"], None), + "__messages__": form_data["messages"], + "__files__": files, + }, + ) + + if model_info: + if model_info.base_model_id: + form_data["model"] = model_info.base_model_id + + params = model_info.params.model_dump() + form_data = apply_model_params_to_body_openai(params, form_data) + form_data = apply_model_system_prompt_to_body(params, form_data, user) + + pipe_id = get_pipe_id(form_data) + function_module = get_function_module_by_id(request, pipe_id) + + pipe = function_module.pipe + params = get_function_params(function_module, form_data, user, extra_params) + + if form_data.get("stream", False): + + async def stream_content(): + try: + res = await execute_pipe(pipe, params) + + # Directly return if the response is a StreamingResponse + if isinstance(res, StreamingResponse): + async for data in res.body_iterator: + yield data + return + if isinstance(res, dict): + yield f"data: {json.dumps(res)}\n\n" + return + + except Exception as e: + log.error(f"Error: {e}") + yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" + return + + if isinstance(res, str): + message = openai_chat_chunk_message_template(form_data["model"], res) + yield f"data: {json.dumps(message)}\n\n" + + if isinstance(res, Iterator): + for line in res: + yield process_line(form_data, line) + + if isinstance(res, AsyncGenerator): + async for line in res: + yield process_line(form_data, line) + + if isinstance(res, str) or isinstance(res, Generator): + finish_message = openai_chat_chunk_message_template( + form_data["model"], "" + ) + finish_message["choices"][0]["finish_reason"] = "stop" + yield f"data: {json.dumps(finish_message)}\n\n" + 
yield "data: [DONE]" + + return StreamingResponse(stream_content(), media_type="text/event-stream") + else: + try: + res = await execute_pipe(pipe, params) + + except Exception as e: + log.error(f"Error: {e}") + return {"error": {"detail": str(e)}} + + if isinstance(res, StreamingResponse) or isinstance(res, dict): + return res + if isinstance(res, BaseModel): + return res.model_dump() + + message = await get_message_content(res) + return openai_chat_completion_message_template(form_data["model"], message) diff --git a/backend/open_webui/apps/webui/internal/db.py b/backend/open_webui/internal/db.py similarity index 97% rename from backend/open_webui/apps/webui/internal/db.py rename to backend/open_webui/internal/db.py index bcf913e6f..ba078822e 100644 --- a/backend/open_webui/apps/webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -3,7 +3,7 @@ import logging from contextlib import contextmanager from typing import Any, Optional -from open_webui.apps.webui.internal.wrappers import register_connection +from open_webui.internal.wrappers import register_connection from open_webui.env import ( OPEN_WEBUI_DIR, DATABASE_URL, diff --git a/backend/open_webui/apps/webui/internal/migrations/001_initial_schema.py b/backend/open_webui/internal/migrations/001_initial_schema.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/001_initial_schema.py rename to backend/open_webui/internal/migrations/001_initial_schema.py diff --git a/backend/open_webui/apps/webui/internal/migrations/002_add_local_sharing.py b/backend/open_webui/internal/migrations/002_add_local_sharing.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/002_add_local_sharing.py rename to backend/open_webui/internal/migrations/002_add_local_sharing.py diff --git a/backend/open_webui/apps/webui/internal/migrations/003_add_auth_api_key.py b/backend/open_webui/internal/migrations/003_add_auth_api_key.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/003_add_auth_api_key.py rename to backend/open_webui/internal/migrations/003_add_auth_api_key.py diff --git a/backend/open_webui/apps/webui/internal/migrations/004_add_archived.py b/backend/open_webui/internal/migrations/004_add_archived.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/004_add_archived.py rename to backend/open_webui/internal/migrations/004_add_archived.py diff --git a/backend/open_webui/apps/webui/internal/migrations/005_add_updated_at.py b/backend/open_webui/internal/migrations/005_add_updated_at.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/005_add_updated_at.py rename to backend/open_webui/internal/migrations/005_add_updated_at.py diff --git a/backend/open_webui/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py b/backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py rename to backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py diff --git a/backend/open_webui/apps/webui/internal/migrations/007_add_user_last_active_at.py b/backend/open_webui/internal/migrations/007_add_user_last_active_at.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/007_add_user_last_active_at.py rename to backend/open_webui/internal/migrations/007_add_user_last_active_at.py diff 
--git a/backend/open_webui/apps/webui/internal/migrations/008_add_memory.py b/backend/open_webui/internal/migrations/008_add_memory.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/008_add_memory.py rename to backend/open_webui/internal/migrations/008_add_memory.py diff --git a/backend/open_webui/apps/webui/internal/migrations/009_add_models.py b/backend/open_webui/internal/migrations/009_add_models.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/009_add_models.py rename to backend/open_webui/internal/migrations/009_add_models.py diff --git a/backend/open_webui/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py b/backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py rename to backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py diff --git a/backend/open_webui/apps/webui/internal/migrations/011_add_user_settings.py b/backend/open_webui/internal/migrations/011_add_user_settings.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/011_add_user_settings.py rename to backend/open_webui/internal/migrations/011_add_user_settings.py diff --git a/backend/open_webui/apps/webui/internal/migrations/012_add_tools.py b/backend/open_webui/internal/migrations/012_add_tools.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/012_add_tools.py rename to backend/open_webui/internal/migrations/012_add_tools.py diff --git a/backend/open_webui/apps/webui/internal/migrations/013_add_user_info.py b/backend/open_webui/internal/migrations/013_add_user_info.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/013_add_user_info.py rename to backend/open_webui/internal/migrations/013_add_user_info.py diff --git a/backend/open_webui/apps/webui/internal/migrations/014_add_files.py b/backend/open_webui/internal/migrations/014_add_files.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/014_add_files.py rename to backend/open_webui/internal/migrations/014_add_files.py diff --git a/backend/open_webui/apps/webui/internal/migrations/015_add_functions.py b/backend/open_webui/internal/migrations/015_add_functions.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/015_add_functions.py rename to backend/open_webui/internal/migrations/015_add_functions.py diff --git a/backend/open_webui/apps/webui/internal/migrations/016_add_valves_and_is_active.py b/backend/open_webui/internal/migrations/016_add_valves_and_is_active.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/016_add_valves_and_is_active.py rename to backend/open_webui/internal/migrations/016_add_valves_and_is_active.py diff --git a/backend/open_webui/apps/webui/internal/migrations/017_add_user_oauth_sub.py b/backend/open_webui/internal/migrations/017_add_user_oauth_sub.py similarity index 100% rename from backend/open_webui/apps/webui/internal/migrations/017_add_user_oauth_sub.py rename to backend/open_webui/internal/migrations/017_add_user_oauth_sub.py diff --git a/backend/open_webui/apps/webui/internal/migrations/018_add_function_is_global.py b/backend/open_webui/internal/migrations/018_add_function_is_global.py similarity index 100% rename from 
backend/open_webui/apps/webui/internal/migrations/018_add_function_is_global.py rename to backend/open_webui/internal/migrations/018_add_function_is_global.py diff --git a/backend/open_webui/apps/webui/internal/wrappers.py b/backend/open_webui/internal/wrappers.py similarity index 100% rename from backend/open_webui/apps/webui/internal/wrappers.py rename to backend/open_webui/internal/wrappers.py diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 0161f76ed..80a4a5504 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -8,9 +8,14 @@ import shutil import sys import time import random -from contextlib import asynccontextmanager -from typing import Optional +from contextlib import asynccontextmanager +from urllib.parse import urlencode, parse_qs, urlparse +from pydantic import BaseModel +from sqlalchemy import text + +from typing import Optional +from aiocache import cached import aiohttp import requests from fastapi import ( @@ -26,116 +31,262 @@ from fastapi import ( from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel -from sqlalchemy import text + from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware from starlette.responses import Response, StreamingResponse -from open_webui.apps.audio.main import app as audio_app -from open_webui.apps.images.main import app as images_app -from open_webui.apps.ollama.main import ( - app as ollama_app, - get_all_models as get_ollama_models, - generate_chat_completion as generate_ollama_chat_completion, - GenerateChatCompletionForm, -) -from open_webui.apps.openai.main import ( - app as openai_app, - generate_chat_completion as generate_openai_chat_completion, - get_all_models as get_openai_models, -) -from open_webui.apps.retrieval.main import app as retrieval_app -from open_webui.apps.retrieval.utils import get_rag_context, rag_template -from open_webui.apps.socket.main import ( + +from open_webui.socket.main import ( app as socket_app, periodic_usage_pool_cleanup, - get_event_call, - get_event_emitter, ) -from open_webui.apps.webui.internal.db import Session -from open_webui.apps.webui.main import ( - app as webui_app, - generate_function_chat_completion, - get_all_models as get_open_webui_models, +from open_webui.routers import ( + audio, + images, + ollama, + openai, + retrieval, + pipelines, + tasks, + auths, + chats, + folders, + configs, + groups, + files, + functions, + memories, + models, + knowledge, + prompts, + evaluations, + tools, + users, + utils, ) -from open_webui.apps.webui.models.functions import Functions -from open_webui.apps.webui.models.models import Models -from open_webui.apps.webui.models.users import UserModel, Users -from open_webui.apps.webui.utils import load_function_module_by_id + +from open_webui.routers.retrieval import ( + get_embedding_function, + get_ef, + get_rf, +) + +from open_webui.internal.db import Session + +from open_webui.models.functions import Functions +from open_webui.models.models import Models +from open_webui.models.users import UserModel, Users + from open_webui.config import ( + # Ollama + ENABLE_OLLAMA_API, + OLLAMA_BASE_URLS, + OLLAMA_API_CONFIGS, + # OpenAI + ENABLE_OPENAI_API, + OPENAI_API_BASE_URLS, + OPENAI_API_KEYS, + OPENAI_API_CONFIGS, + # Image + AUTOMATIC1111_API_AUTH, + 
AUTOMATIC1111_BASE_URL, + AUTOMATIC1111_CFG_SCALE, + AUTOMATIC1111_SAMPLER, + AUTOMATIC1111_SCHEDULER, + COMFYUI_BASE_URL, + COMFYUI_WORKFLOW, + COMFYUI_WORKFLOW_NODES, + ENABLE_IMAGE_GENERATION, + IMAGE_GENERATION_ENGINE, + IMAGE_GENERATION_MODEL, + IMAGE_SIZE, + IMAGE_STEPS, + IMAGES_OPENAI_API_BASE_URL, + IMAGES_OPENAI_API_KEY, + # Audio + AUDIO_STT_ENGINE, + AUDIO_STT_MODEL, + AUDIO_STT_OPENAI_API_BASE_URL, + AUDIO_STT_OPENAI_API_KEY, + AUDIO_TTS_API_KEY, + AUDIO_TTS_ENGINE, + AUDIO_TTS_MODEL, + AUDIO_TTS_OPENAI_API_BASE_URL, + AUDIO_TTS_OPENAI_API_KEY, + AUDIO_TTS_SPLIT_ON, + AUDIO_TTS_VOICE, + AUDIO_TTS_AZURE_SPEECH_REGION, + AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, + WHISPER_MODEL, + WHISPER_MODEL_AUTO_UPDATE, + WHISPER_MODEL_DIR, + # Retrieval + RAG_TEMPLATE, + DEFAULT_RAG_TEMPLATE, + RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, + RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + RAG_RERANKING_MODEL, + RAG_RERANKING_MODEL_AUTO_UPDATE, + RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + RAG_EMBEDDING_ENGINE, + RAG_EMBEDDING_BATCH_SIZE, + RAG_RELEVANCE_THRESHOLD, + RAG_FILE_MAX_COUNT, + RAG_FILE_MAX_SIZE, + RAG_OPENAI_API_BASE_URL, + RAG_OPENAI_API_KEY, + RAG_OLLAMA_BASE_URL, + RAG_OLLAMA_API_KEY, + CHUNK_OVERLAP, + CHUNK_SIZE, + CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL, + RAG_TOP_K, + RAG_TEXT_SPLITTER, + TIKTOKEN_ENCODING_NAME, + PDF_EXTRACT_IMAGES, + YOUTUBE_LOADER_LANGUAGE, + YOUTUBE_LOADER_PROXY_URL, + # Retrieval (Web Search) + RAG_WEB_SEARCH_ENGINE, + RAG_WEB_SEARCH_RESULT_COUNT, + RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + JINA_API_KEY, + SEARCHAPI_API_KEY, + SEARCHAPI_ENGINE, + SEARXNG_QUERY_URL, + SERPER_API_KEY, + SERPLY_API_KEY, + SERPSTACK_API_KEY, + SERPSTACK_HTTPS, + TAVILY_API_KEY, + BING_SEARCH_V7_ENDPOINT, + BING_SEARCH_V7_SUBSCRIPTION_KEY, + BRAVE_SEARCH_API_KEY, + KAGI_SEARCH_API_KEY, + MOJEEK_SEARCH_API_KEY, + GOOGLE_PSE_API_KEY, + GOOGLE_PSE_ENGINE_ID, + ENABLE_RAG_HYBRID_SEARCH, + ENABLE_RAG_LOCAL_WEB_FETCH, + ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + ENABLE_RAG_WEB_SEARCH, + UPLOAD_DIR, + # WebUI + WEBUI_AUTH, + WEBUI_NAME, + WEBUI_BANNERS, + WEBHOOK_URL, + ADMIN_EMAIL, + SHOW_ADMIN_DETAILS, + JWT_EXPIRES_IN, + ENABLE_SIGNUP, + ENABLE_LOGIN_FORM, + ENABLE_API_KEY, + ENABLE_COMMUNITY_SHARING, + ENABLE_MESSAGE_RATING, + ENABLE_EVALUATION_ARENA_MODELS, + USER_PERMISSIONS, + DEFAULT_USER_ROLE, + DEFAULT_PROMPT_SUGGESTIONS, + DEFAULT_MODELS, + DEFAULT_ARENA_MODEL, + MODEL_ORDER_LIST, + EVALUATION_ARENA_MODELS, + # WebUI (OAuth) + ENABLE_OAUTH_ROLE_MANAGEMENT, + OAUTH_ROLES_CLAIM, + OAUTH_EMAIL_CLAIM, + OAUTH_PICTURE_CLAIM, + OAUTH_USERNAME_CLAIM, + OAUTH_ALLOWED_ROLES, + OAUTH_ADMIN_ROLES, + # WebUI (LDAP) + ENABLE_LDAP, + LDAP_SERVER_LABEL, + LDAP_SERVER_HOST, + LDAP_SERVER_PORT, + LDAP_ATTRIBUTE_FOR_USERNAME, + LDAP_SEARCH_FILTERS, + LDAP_SEARCH_BASE, + LDAP_APP_DN, + LDAP_APP_PASSWORD, + LDAP_USE_TLS, + LDAP_CA_CERT_FILE, + LDAP_CIPHERS, + # Misc + ENV, CACHE_DIR, + STATIC_DIR, + FRONTEND_BUILD_DIR, CORS_ALLOW_ORIGIN, DEFAULT_LOCALE, + OAUTH_PROVIDERS, + GOOGLE_DRIVE_CLIENT_ID, + GOOGLE_DRIVE_API_KEY, + # Admin ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT, - ENABLE_MODEL_FILTER, - ENABLE_OLLAMA_API, - ENABLE_OPENAI_API, - ENV, - FRONTEND_BUILD_DIR, - MODEL_FILTER_LIST, - OAUTH_PROVIDERS, - ENABLE_SEARCH_QUERY, - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, - STATIC_DIR, + # Tasks TASK_MODEL, TASK_MODEL_EXTERNAL, + ENABLE_TAGS_GENERATION, + ENABLE_SEARCH_QUERY_GENERATION, + ENABLE_RETRIEVAL_QUERY_GENERATION, + 
ENABLE_AUTOCOMPLETE_GENERATION, TITLE_GENERATION_PROMPT_TEMPLATE, TAGS_GENERATION_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - WEBHOOK_URL, - WEBUI_AUTH, - WEBUI_NAME, - GOOGLE_DRIVE_CLIENT_ID, - GOOGLE_DRIVE_API_KEY, + QUERY_GENERATION_PROMPT_TEMPLATE, + AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, AppConfig, reset_config, ) -from open_webui.constants import TASKS from open_webui.env import ( CHANGELOG, GLOBAL_LOG_LEVEL, SAFE_MODE, SRC_LOG_LEVELS, VERSION, + WEBUI_URL, WEBUI_BUILD_HASH, WEBUI_SECRET_KEY, WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE, - WEBUI_URL, + WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + WEBUI_AUTH_TRUSTED_NAME_HEADER, + BYPASS_MODEL_ACCESS_CONTROL, RESET_CONFIG_ON_START, OFFLINE_MODE, ) -from open_webui.utils.misc import ( - add_or_update_system_message, - get_last_user_message, - prepend_to_first_user_message_content, + + +from open_webui.utils.models import ( + get_all_models, + get_all_base_models, + check_model_access, ) -from open_webui.utils.oauth import oauth_manager -from open_webui.utils.payload import convert_payload_openai_to_ollama -from open_webui.utils.response import ( - convert_response_ollama_to_openai, - convert_streaming_response_ollama_to_openai, +from open_webui.utils.chat import ( + generate_chat_completion as chat_completion_handler, + chat_completed as chat_completed_handler, + chat_action as chat_action_handler, ) -from open_webui.utils.security_headers import SecurityHeadersMiddleware -from open_webui.utils.task import ( - moa_response_generation_template, - tags_generation_template, - search_query_generation_template, - emoji_generation_template, - title_generation_template, - tools_function_calling_generation_template, -) -from open_webui.utils.tools import get_tools -from open_webui.utils.utils import ( +from open_webui.utils.middleware import process_chat_payload, process_chat_response +from open_webui.utils.access_control import has_access + +from open_webui.utils.auth import ( decode_token, get_admin_user, - get_current_user, - get_http_authorization_cred, get_verified_user, ) +from open_webui.utils.oauth import oauth_manager +from open_webui.utils.security_headers import SecurityHeadersMiddleware + if SAFE_MODE: print("SAFE MODE ENABLED") @@ -184,642 +335,308 @@ async def lifespan(app: FastAPI): app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, + lifespan=lifespan, ) app.state.config = AppConfig() -app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API + +######################################## +# +# OLLAMA +# +######################################## + + app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API +app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS +app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS -app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.OLLAMA_MODELS = {} +######################################## +# +# OPENAI +# +######################################## + +app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API +app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS +app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS +app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS + +app.state.OPENAI_MODELS = {} + 
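Most of the flags wired into `app.state.config` in this file (ENABLE_OLLAMA_API, ENABLE_OPENAI_API, ENABLE_API_KEY, BYPASS_MODEL_ACCESS_CONTROL, and so on) come from the same boolean environment-variable idiom used throughout config.py and env.py: anything other than a case-insensitive "true" leaves the feature off. A minimal, self-contained illustration of that idiom — the `env_flag` helper is ours, not part of the codebase:

```python
# Sketch of the env-flag convention used by the settings above.
import os


def env_flag(name: str, default: str = "False") -> bool:
    # Mirrors the `os.environ.get(name, default).lower() == "true"` pattern.
    return os.environ.get(name, default).lower() == "true"


os.environ.pop("ENABLE_LDAP", None)
os.environ.pop("ENABLE_AUTOCOMPLETE_GENERATION", None)
os.environ["BYPASS_MODEL_ACCESS_CONTROL"] = "True"

assert env_flag("BYPASS_MODEL_ACCESS_CONTROL") is True
assert env_flag("ENABLE_LDAP", "false") is False                 # unset -> default applies
assert env_flag("ENABLE_AUTOCOMPLETE_GENERATION", "True") is True  # enabled by default
```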
+######################################## +# +# WEBUI +# +######################################## + +app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP +app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM +app.state.config.ENABLE_API_KEY = ENABLE_API_KEY + +app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN + +app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS +app.state.config.ADMIN_EMAIL = ADMIN_EMAIL + + +app.state.config.DEFAULT_MODELS = DEFAULT_MODELS +app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS +app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE + +app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.WEBHOOK_URL = WEBHOOK_URL +app.state.config.BANNERS = WEBUI_BANNERS +app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST + +app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING +app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING + +app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS +app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS + +app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM +app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM +app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM + +app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT +app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM +app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES +app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES + +app.state.config.ENABLE_LDAP = ENABLE_LDAP +app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL +app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST +app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT +app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME +app.state.config.LDAP_APP_DN = LDAP_APP_DN +app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD +app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE +app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS +app.state.config.LDAP_USE_TLS = LDAP_USE_TLS +app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE +app.state.config.LDAP_CIPHERS = LDAP_CIPHERS + + +app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER +app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER + +app.state.TOOLS = {} +app.state.FUNCTIONS = {} + + +######################################## +# +# RETRIEVAL +# +######################################## + + +app.state.config.TOP_K = RAG_TOP_K +app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD +app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE +app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT + +app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH +app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( + ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION +) + +app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE +app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL + +app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER +app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME + +app.state.config.CHUNK_SIZE = CHUNK_SIZE +app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP + +app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE +app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL +app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE +app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL +app.state.config.RAG_TEMPLATE = RAG_TEMPLATE + +app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL 
+app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY + +app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL +app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY + +app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES + +app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE +app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL + + +app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH +app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE +app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST + +app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL +app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY +app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID +app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY +app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY +app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY +app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY +app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS +app.state.config.SERPER_API_KEY = SERPER_API_KEY +app.state.config.SERPLY_API_KEY = SERPLY_API_KEY +app.state.config.TAVILY_API_KEY = TAVILY_API_KEY +app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY +app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE +app.state.config.JINA_API_KEY = JINA_API_KEY +app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT +app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY + +app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT +app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS + +app.state.EMBEDDING_FUNCTION = None +app.state.ef = None +app.state.rf = None + +app.state.YOUTUBE_LOADER_TRANSLATION = None + + +app.state.EMBEDDING_FUNCTION = get_embedding_function( + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, + app.state.ef, + ( + app.state.config.RAG_OPENAI_API_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.RAG_OLLAMA_BASE_URL + ), + ( + app.state.config.RAG_OPENAI_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.RAG_OLLAMA_API_KEY + ), + app.state.config.RAG_EMBEDDING_BATCH_SIZE, +) + +try: + app.state.ef = get_ef( + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, + ) + + app.state.rf = get_rf( + app.state.config.RAG_RERANKING_MODEL, + RAG_RERANKING_MODEL_AUTO_UPDATE, + ) +except Exception as e: + log.error(f"Error updating models: {e}") + pass + + +######################################## +# +# IMAGES +# +######################################## + +app.state.config.IMAGE_GENERATION_ENGINE = IMAGE_GENERATION_ENGINE +app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION + +app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL +app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY + +app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL + +app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL +app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH +app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE +app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER +app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER +app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL +app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW 
+app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES + +app.state.config.IMAGE_SIZE = IMAGE_SIZE +app.state.config.IMAGE_STEPS = IMAGE_STEPS + + +######################################## +# +# AUDIO +# +######################################## + +app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL +app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY +app.state.config.STT_ENGINE = AUDIO_STT_ENGINE +app.state.config.STT_MODEL = AUDIO_STT_MODEL + +app.state.config.WHISPER_MODEL = WHISPER_MODEL + +app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL +app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY +app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE +app.state.config.TTS_MODEL = AUDIO_TTS_MODEL +app.state.config.TTS_VOICE = AUDIO_TTS_VOICE +app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY +app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON + + +app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION +app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT + + +app.state.faster_whisper_model = None +app.state.speech_synthesiser = None +app.state.speech_speaker_embeddings_dataset = None + + +######################################## +# +# TASKS +# +######################################## + app.state.config.TASK_MODEL = TASK_MODEL app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL + + +app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION +app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION +app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION +app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION + + app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE -app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE -) -app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE ) +app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE +app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = ( + AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE +) +app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH +) + + +######################################## +# +# WEBUI +# +######################################## app.state.MODELS = {} -################################## -# -# ChatCompletion Middleware -# -################################## - - -def get_task_model_id(default_model_id): - # Set the task model - task_model_id = default_model_id - # Check if the user has a custom task model and use that model - if app.state.MODELS[task_model_id]["owned_by"] == "ollama": - if ( - app.state.config.TASK_MODEL - and app.state.config.TASK_MODEL in app.state.MODELS - ): - task_model_id = app.state.config.TASK_MODEL - else: - if ( - app.state.config.TASK_MODEL_EXTERNAL - and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS - ): - task_model_id = app.state.config.TASK_MODEL_EXTERNAL - - return task_model_id - - -def get_filter_function_ids(model): - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel - return (function.valves if 
function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - filter_ids.sort(key=get_priority) - return filter_ids - - -async def chat_completion_filter_functions_handler(body, model, extra_params): - skip_files = None - - filter_ids = get_filter_function_ids(model) - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module - - # Check if the function has a file_handler variable - if hasattr(function_module, "file_handler"): - skip_files = function_module.file_handler - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if not hasattr(function_module, "inlet"): - continue - - try: - inlet = function_module.inlet - - # Get the signature of the function - sig = inspect.signature(inlet) - params = {"body": body} | { - k: v - for k, v in { - **extra_params, - "__model__": model, - "__id__": filter_id, - }.items() - if k in sig.parameters - } - - if "__user__" in params and hasattr(function_module, "UserValves"): - try: - params["__user__"]["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, params["__user__"]["id"] - ) - ) - except Exception as e: - print(e) - - if inspect.iscoroutinefunction(inlet): - body = await inlet(**params) - else: - body = inlet(**params) - - except Exception as e: - print(f"Error: {e}") - raise e - - if skip_files and "files" in body.get("metadata", {}): - del body["metadata"]["files"] - - return body, {} - - -def get_tools_function_calling_payload(messages, task_model_id, content): - user_message = get_last_user_message(messages) - history = "\n".join( - f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" - for message in messages[::-1][:4] - ) - - prompt = f"History:\n{history}\nQuery: {user_message}" - - return { - "model": task_model_id, - "messages": [ - {"role": "system", "content": content}, - {"role": "user", "content": f"Query: {prompt}"}, - ], - "stream": False, - "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, - } - - -async def get_content_from_response(response) -> Optional[str]: - content = None - if hasattr(response, "body_iterator"): - async for chunk in response.body_iterator: - data = json.loads(chunk.decode("utf-8")) - content = data["choices"][0]["message"]["content"] - - # Cleanup any remaining background tasks if necessary - if response.background is not None: - await response.background() - else: - content = response["choices"][0]["message"]["content"] - return content - - -async def chat_completion_tools_handler( - body: dict, user: UserModel, extra_params: dict -) -> tuple[dict, dict]: - # If tool_ids field is present, call the functions - metadata = body.get("metadata", {}) - - 
tool_ids = metadata.get("tool_ids", None) - log.debug(f"{tool_ids=}") - if not tool_ids: - return body, {} - - skip_files = False - contexts = [] - citations = [] - - task_model_id = get_task_model_id(body["model"]) - tools = get_tools( - webui_app, - tool_ids, - user, - { - **extra_params, - "__model__": app.state.MODELS[task_model_id], - "__messages__": body["messages"], - "__files__": metadata.get("files", []), - }, - ) - log.info(f"{tools=}") - - specs = [tool["spec"] for tool in tools.values()] - tools_specs = json.dumps(specs) - - if app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "": - template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - else: - template = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text.""" - - tools_function_calling_prompt = tools_function_calling_generation_template( - template, tools_specs - ) - log.info(f"{tools_function_calling_prompt=}") - payload = get_tools_function_calling_payload( - body["messages"], task_model_id, tools_function_calling_prompt - ) - - try: - payload = filter_pipeline(payload, user) - except Exception as e: - raise e - - try: - response = await generate_chat_completions(form_data=payload, user=user) - log.debug(f"{response=}") - content = await get_content_from_response(response) - log.debug(f"{content=}") - - if not content: - return body, {} - - try: - content = content[content.find("{") : content.rfind("}") + 1] - if not content: - raise Exception("No JSON object found in the response") - - result = json.loads(content) - - tool_function_name = result.get("name", None) - if tool_function_name not in tools: - return body, {} - - tool_function_params = result.get("parameters", {}) - - try: - required_params = ( - tools[tool_function_name] - .get("spec", {}) - .get("parameters", {}) - .get("required", []) - ) - tool_function = tools[tool_function_name]["callable"] - tool_function_params = { - k: v - for k, v in tool_function_params.items() - if k in required_params - } - tool_output = await tool_function(**tool_function_params) - - except Exception as e: - tool_output = str(e) - - if tools[tool_function_name]["citation"]: - citations.append( - { - "source": { - "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" - }, - "document": [tool_output], - "metadata": [{"source": tool_function_name}], - } - ) - if tools[tool_function_name]["file_handler"]: - skip_files = True - - if isinstance(tool_output, str): - contexts.append(tool_output) - except Exception as e: - log.exception(f"Error: {e}") - content = None - except Exception as e: - log.exception(f"Error: {e}") - content = None - - log.debug(f"tool_contexts: {contexts}") - - if skip_files and "files" in body.get("metadata", {}): - del body["metadata"]["files"] - - return body, {"contexts": contexts, "citations": citations} - - -async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: - contexts = [] - citations = [] - - if files := body.get("metadata", {}).get("files", None): - contexts, citations = get_rag_context( - files=files, - messages=body["messages"], - embedding_function=retrieval_app.state.EMBEDDING_FUNCTION, - k=retrieval_app.state.config.TOP_K, - 
reranking_function=retrieval_app.state.sentence_transformer_rf, - r=retrieval_app.state.config.RELEVANCE_THRESHOLD, - hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH, - ) - - log.debug(f"rag_contexts: {contexts}, citations: {citations}") - - return body, {"contexts": contexts, "citations": citations} - - -def is_chat_completion_request(request): - return request.method == "POST" and any( - endpoint in request.url.path - for endpoint in ["/ollama/api/chat", "/chat/completions"] - ) - - -async def get_body_and_model_and_user(request): - # Read the original request body - body = await request.body() - body_str = body.decode("utf-8") - body = json.loads(body_str) if body_str else {} - - model_id = body["model"] - if model_id not in app.state.MODELS: - raise Exception("Model not found") - model = app.state.MODELS[model_id] - - user = get_current_user( - request, - get_http_authorization_cred(request.headers.get("Authorization")), - ) - - return body, model, user - - -class ChatCompletionMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - if not is_chat_completion_request(request): - return await call_next(request) - log.debug(f"request.url.path: {request.url.path}") - - try: - body, model, user = await get_body_and_model_and_user(request) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - metadata = { - "chat_id": body.pop("chat_id", None), - "message_id": body.pop("id", None), - "session_id": body.pop("session_id", None), - "tool_ids": body.get("tool_ids", None), - "files": body.get("files", None), - } - body["metadata"] = metadata - - extra_params = { - "__event_emitter__": get_event_emitter(metadata), - "__event_call__": get_event_call(metadata), - "__user__": { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - } - - # Initialize data_items to store additional data to be sent to the client - # Initialize contexts and citation - data_items = [] - contexts = [] - citations = [] - - try: - body, flags = await chat_completion_filter_functions_handler( - body, model, extra_params - ) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - metadata = { - **metadata, - "tool_ids": body.pop("tool_ids", None), - "files": body.pop("files", None), - } - body["metadata"] = metadata - - try: - body, flags = await chat_completion_tools_handler(body, user, extra_params) - contexts.extend(flags.get("contexts", [])) - citations.extend(flags.get("citations", [])) - except Exception as e: - log.exception(e) - - try: - body, flags = await chat_completion_files_handler(body) - contexts.extend(flags.get("contexts", [])) - citations.extend(flags.get("citations", [])) - except Exception as e: - log.exception(e) - - # If context is not empty, insert it into the messages - if len(contexts) > 0: - context_string = "/n".join(contexts).strip() - prompt = get_last_user_message(body["messages"]) - - if prompt is None: - raise Exception("No user message found") - if ( - retrieval_app.state.config.RELEVANCE_THRESHOLD == 0 - and context_string.strip() == "" - ): - log.debug( - f"With a 0 relevancy threshold for RAG, the context cannot be empty" - ) - - # Workaround for Ollama 2.0+ system prompt issue - # TODO: replace with add_or_update_system_message - if model["owned_by"] == "ollama": - body["messages"] = prepend_to_first_user_message_content( - rag_template( - 
retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt - ), - body["messages"], - ) - else: - body["messages"] = add_or_update_system_message( - rag_template( - retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt - ), - body["messages"], - ) - - # If there are citations, add them to the data_items - if len(citations) > 0: - data_items.append({"citations": citations}) - - modified_body_bytes = json.dumps(body).encode("utf-8") - # Replace the request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], - ] - - response = await call_next(request) - if not isinstance(response, StreamingResponse): - return response - - content_type = response.headers["Content-Type"] - is_openai = "text/event-stream" in content_type - is_ollama = "application/x-ndjson" in content_type - if not is_openai and not is_ollama: - return response - - def wrap_item(item): - return f"data: {item}\n\n" if is_openai else f"{item}\n" - - async def stream_wrapper(original_generator, data_items): - for item in data_items: - yield wrap_item(json.dumps(item)) - - async for data in original_generator: - yield data - - return StreamingResponse( - stream_wrapper(response.body_iterator, data_items), - headers=dict(response.headers), - ) - - async def _receive(self, body: bytes): - return {"type": "http.request", "body": body, "more_body": False} - - -app.add_middleware(ChatCompletionMiddleware) - - -################################## -# -# Pipeline Middleware -# -################################## - - -def get_sorted_filters(model_id): - filters = [ - model - for model in app.state.MODELS.values() - if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) - ] - sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) - return sorted_filters - - -def filter_pipeline(payload, user): - user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} - model_id = payload["model"] - sorted_filters = get_sorted_filters(model_id) - - model = app.state.MODELS[model_id] - - if "pipeline" in model: - sorted_filters.append(model) - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key == "": - continue - - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/inlet", - headers=headers, - json={ - "user": user, - "body": payload, - }, - ) - - r.raise_for_status() - payload = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - res = r.json() - if "detail" in res: - raise Exception(r.status_code, res["detail"]) - - return payload - - -class PipelineMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - if not is_chat_completion_request(request): - return await call_next(request) - - log.debug(f"request.url.path: {request.url.path}") - - # Read the original request body - body = await request.body() - # Decode body to string - 
body_str = body.decode("utf-8") - # Parse string to JSON - data = json.loads(body_str) if body_str else {} - - try: - user = get_current_user( - request, - get_http_authorization_cred(request.headers["Authorization"]), - ) - except KeyError as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={"detail": "Not authenticated"}, - ) - - try: - data = filter_pipeline(data, user) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - modified_body_bytes = json.dumps(data).encode("utf-8") - # Replace the request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], - ] - - response = await call_next(request) - return response - - async def _receive(self, body: bytes): - return {"type": "http.request", "body": body, "more_body": False} - - -app.add_middleware(PipelineMiddleware) - - -from urllib.parse import urlencode, parse_qs, urlparse - - class RedirectMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): # Check if the request is a GET request @@ -841,47 +658,24 @@ class RedirectMiddleware(BaseHTTPMiddleware): # Add the middleware to the app app.add_middleware(RedirectMiddleware) - - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - app.add_middleware(SecurityHeadersMiddleware) @app.middleware("http") async def commit_session_after_request(request: Request, call_next): response = await call_next(request) - log.debug("Commit session after request") + # log.debug("Commit session after request") Session.commit() return response @app.middleware("http") async def check_url(request: Request, call_next): - if len(app.state.MODELS) == 0: - await get_all_models() - else: - pass - start_time = int(time.time()) + request.state.enable_api_key = app.state.config.ENABLE_API_KEY response = await call_next(request) process_time = int(time.time()) - start_time response.headers["X-Process-Time"] = str(process_time) - - return response - - -@app.middleware("http") -async def update_embedding_function(request: Request, call_next): - response = await call_next(request) - if "/embedding/update" in request.url.path: - webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION return response @@ -903,176 +697,84 @@ async def inspect_websocket(request: Request, call_next): return await call_next(request) +app.add_middleware( + CORSMiddleware, + allow_origins=CORS_ALLOW_ORIGIN, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + app.mount("/ws", socket_app) -app.mount("/ollama", ollama_app) -app.mount("/openai", openai_app) - -app.mount("/images/api/v1", images_app) -app.mount("/audio/api/v1", audio_app) -app.mount("/retrieval/api/v1", retrieval_app) - -app.mount("/api/v1", webui_app) -webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION +app.include_router(ollama.router, prefix="/ollama", tags=["ollama"]) 
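Aside for readers following the refactor: the sub-application mounts removed above are replaced by `APIRouter` instances included directly on the single FastAPI app, each under its own prefix and tag. A minimal, self-contained sketch of that pattern is below; the `demo` router, path, and prefix are illustrative placeholders, not part of this PR.

```python
# Illustrative sketch only: APIRouter composition replacing mounted sub-apps.
# The "demo" router, handler, and prefix are hypothetical, not from this PR.
from fastapi import APIRouter, FastAPI

demo = APIRouter()


@demo.get("/ping")
async def ping():
    # A trivial endpoint; the real routers (ollama, openai, retrieval, ...) carry
    # the feature endpoints that previously lived on separate FastAPI sub-apps.
    return {"status": True}


app = FastAPI()
# Before: features were separate apps composed via app.mount("/prefix", sub_app).
# After: each feature exposes an APIRouter included on the one app:
app.include_router(demo, prefix="/api/v1/demo", tags=["demo"])
```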
+app.include_router(openai.router, prefix="/openai", tags=["openai"]) -async def get_all_models(): - # TODO: Optimize this function - open_webui_models = [] - openai_models = [] - ollama_models = [] +app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"]) +app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"]) +app.include_router(images.router, prefix="/api/v1/images", tags=["images"]) +app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"]) +app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"]) - if app.state.config.ENABLE_OPENAI_API: - openai_models = await get_openai_models() - openai_models = openai_models["data"] +app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"]) - if app.state.config.ENABLE_OLLAMA_API: - ollama_models = await get_ollama_models() - ollama_models = [ - { - "id": model["model"], - "name": model["name"], - "object": "model", - "created": int(time.time()), - "owned_by": "ollama", - "ollama": model, - } - for model in ollama_models["models"] - ] +app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"]) +app.include_router(users.router, prefix="/api/v1/users", tags=["users"]) - open_webui_models = await get_open_webui_models() +app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"]) - models = open_webui_models + openai_models + ollama_models +app.include_router(models.router, prefix="/api/v1/models", tags=["models"]) +app.include_router(knowledge.router, prefix="/api/v1/knowledge", tags=["knowledge"]) +app.include_router(prompts.router, prefix="/api/v1/prompts", tags=["prompts"]) +app.include_router(tools.router, prefix="/api/v1/tools", tags=["tools"]) - # If there are no models, return an empty list - if len([model for model in models if model["owned_by"] != "arena"]) == 0: - return [] +app.include_router(memories.router, prefix="/api/v1/memories", tags=["memories"]) +app.include_router(folders.router, prefix="/api/v1/folders", tags=["folders"]) +app.include_router(groups.router, prefix="/api/v1/groups", tags=["groups"]) +app.include_router(files.router, prefix="/api/v1/files", tags=["files"]) +app.include_router(functions.router, prefix="/api/v1/functions", tags=["functions"]) +app.include_router( + evaluations.router, prefix="/api/v1/evaluations", tags=["evaluations"] +) +app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"]) - global_action_ids = [ - function.id for function in Functions.get_global_action_functions() - ] - enabled_action_ids = [ - function.id - for function in Functions.get_functions_by_type("action", active_only=True) - ] - custom_models = Models.get_all_models() - for custom_model in custom_models: - if custom_model.base_model_id is None: - for model in models: - if ( - custom_model.id == model["id"] - or custom_model.id == model["id"].split(":")[0] - ): - model["name"] = custom_model.name - model["info"] = custom_model.model_dump() - - action_ids = [] - if "info" in model and "meta" in model["info"]: - action_ids.extend(model["info"]["meta"].get("actionIds", [])) - - model["action_ids"] = action_ids - else: - owned_by = "openai" - pipe = None - action_ids = [] - - for model in models: - if ( - custom_model.base_model_id == model["id"] - or custom_model.base_model_id == model["id"].split(":")[0] - ): - owned_by = model["owned_by"] - if "pipe" in model: - pipe = model["pipe"] - break - - if custom_model.meta: - meta = custom_model.meta.model_dump() - if "actionIds" in meta: - 
action_ids.extend(meta["actionIds"]) - - models.append( - { - "id": custom_model.id, - "name": custom_model.name, - "object": "model", - "created": custom_model.created_at, - "owned_by": owned_by, - "info": custom_model.model_dump(), - "preset": True, - **({"pipe": pipe} if pipe is not None else {}), - "action_ids": action_ids, - } - ) - - for model in models: - action_ids = [] - if "action_ids" in model: - action_ids = model["action_ids"] - del model["action_ids"] - - action_ids = action_ids + global_action_ids - action_ids = list(set(action_ids)) - action_ids = [ - action_id for action_id in action_ids if action_id in enabled_action_ids - ] - - model["actions"] = [] - for action_id in action_ids: - action = Functions.get_function_by_id(action_id) - if action is None: - raise Exception(f"Action not found: {action_id}") - - if action_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[action_id] - else: - function_module, _, _ = load_function_module_by_id(action_id) - webui_app.state.FUNCTIONS[action_id] = function_module - - __webui__ = False - if hasattr(function_module, "__webui__"): - __webui__ = function_module.__webui__ - - if hasattr(function_module, "actions"): - actions = function_module.actions - model["actions"].extend( - [ - { - "id": f"{action_id}.{_action['id']}", - "name": _action.get( - "name", f"{action.name} ({_action['id']})" - ), - "description": action.meta.description, - "icon_url": _action.get( - "icon_url", action.meta.manifest.get("icon_url", None) - ), - **({"__webui__": __webui__} if __webui__ else {}), - } - for _action in actions - ] - ) - else: - model["actions"].append( - { - "id": action_id, - "name": action.name, - "description": action.meta.description, - "icon_url": action.meta.manifest.get("icon_url", None), - **({"__webui__": __webui__} if __webui__ else {}), - } - ) - - app.state.MODELS = {model["id"]: model for model in models} - webui_app.state.MODELS = app.state.MODELS - - return models +################################## +# +# Chat Endpoints +# +################################## @app.get("/api/models") -async def get_models(user=Depends(get_verified_user)): - models = await get_all_models() +async def get_models(request: Request, user=Depends(get_verified_user)): + def get_filtered_models(models, user): + filtered_models = [] + for model in models: + if model.get("arena"): + if has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + filtered_models.append(model) + continue + + model_info = Models.get_model_by_id(model["id"]) + if model_info: + if user.id == model_info.user_id or has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + + return filtered_models + + models = await get_all_models(request) # Filter out filter pipelines models = [ @@ -1081,1125 +783,100 @@ async def get_models(user=Depends(get_verified_user)): if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" ] - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user": - models = list( - filter( - lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, - models, - ) - ) - return {"data": models} + model_order_list = request.app.state.config.MODEL_ORDER_LIST + if model_order_list: + model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)} + # Sort models by order list priority, with fallback for those not in the list + models.sort( + key=lambda x: 
(model_order_dict.get(x["id"], float("inf")), x["name"]) + ) + # Filter out models that the user does not have access to + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: + models = get_filtered_models(models, user) + + log.debug( + f"/api/models returned filtered models accessible to the user: {json.dumps([model['id'] for model in models])}" + ) + return {"data": models} + + +@app.get("/api/models/base") +async def get_base_models(request: Request, user=Depends(get_admin_user)): + models = await get_all_base_models(request) return {"data": models} @app.post("/api/chat/completions") -async def generate_chat_completions( - form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False +async def chat_completion( + request: Request, + form_data: dict, + user=Depends(get_verified_user), + bypass_filter: bool = False, ): - model_id = form_data["model"] + if not request.app.state.MODELS: + await get_all_models(request) - if model_id not in app.state.MODELS: + try: + model_id = form_data.get("model", None) + if model_id not in request.app.state.MODELS: + raise Exception("Model not found") + model = request.app.state.MODELS[model_id] + + # Check if user has access to the model + if not bypass_filter and user.role == "user": + try: + check_model_access(user, model) + except Exception as e: + raise e + + form_data, events = await process_chat_payload(request, form_data, user, model) + except Exception as e: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), ) - if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Model not found", - ) - - model = app.state.MODELS[model_id] - - if model["owned_by"] == "arena": - model_ids = model.get("info", {}).get("meta", {}).get("model_ids") - filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") - if model_ids and filter_mode == "exclude": - model_ids = [ - model["id"] - for model in await get_all_models() - if model.get("owned_by") != "arena" - and not model.get("info", {}).get("meta", {}).get("hidden", False) - and model["id"] not in model_ids - ] - - selected_model_id = None - if isinstance(model_ids, list) and model_ids: - selected_model_id = random.choice(model_ids) - else: - model_ids = [ - model["id"] - for model in await get_all_models() - if model.get("owned_by") != "arena" - and not model.get("info", {}).get("meta", {}).get("hidden", False) - ] - selected_model_id = random.choice(model_ids) - - form_data["model"] = selected_model_id - - if form_data.get("stream") == True: - - async def stream_wrapper(stream): - yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" - async for chunk in stream: - yield chunk - - response = await generate_chat_completions( - form_data, user, bypass_filter=True - ) - return StreamingResponse( - stream_wrapper(response.body_iterator), media_type="text/event-stream" - ) - else: - return { - **( - await generate_chat_completions(form_data, user, bypass_filter=True) - ), - "selected_model_id": selected_model_id, - } - if model.get("pipe"): - return await generate_function_chat_completion(form_data, user=user) - if model["owned_by"] == "ollama": - # Using /ollama/api/chat endpoint - form_data = convert_payload_openai_to_ollama(form_data) - form_data = GenerateChatCompletionForm(**form_data) - response = 
await generate_ollama_chat_completion( - form_data=form_data, user=user, bypass_filter=True + try: + response = await chat_completion_handler( + request, form_data, user, bypass_filter ) - if form_data.stream: - response.headers["content-type"] = "text/event-stream" - return StreamingResponse( - convert_streaming_response_ollama_to_openai(response), - headers=dict(response.headers), - ) - else: - return convert_response_ollama_to_openai(response) - else: - return await generate_openai_chat_completion(form_data, user=user) + return await process_chat_response(response, events) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + + +# Alias for chat_completion (Legacy) +generate_chat_completions = chat_completion +generate_chat_completion = chat_completion @app.post("/api/chat/completed") -async def chat_completed(form_data: dict, user=Depends(get_verified_user)): - data = form_data - model_id = data["model"] - if model_id not in app.state.MODELS: +async def chat_completed( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + try: + return await chat_completed_handler(request, form_data, user) + except Exception as e: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), ) - model = app.state.MODELS[model_id] - - sorted_filters = get_sorted_filters(model_id) - if "pipeline" in model: - sorted_filters = [model] + sorted_filters - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/outlet", - headers=headers, - json={ - "user": { - "id": user.id, - "name": user.name, - "email": user.email, - "role": user.role, - }, - "body": data, - }, - ) - - r.raise_for_status() - data = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - try: - res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) - except Exception: - pass - - else: - pass - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel to include vavles - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - # Sort filter_ids by priority, using the get_priority function - filter_ids.sort(key=get_priority) - - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: 
- continue - - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if not hasattr(function_module, "outlet"): - continue - try: - outlet = function_module.outlet - - # Get the signature of the function - sig = inspect.signature(outlet) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(outlet): - data = await outlet(**params) - else: - data = outlet(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data @app.post("/api/chat/actions/{action_id}") -async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)): - if "." 
in action_id: - action_id, sub_action_id = action_id.split(".") - else: - sub_action_id = None - - action = Functions.get_function_by_id(action_id) - if not action: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Action not found", - ) - - data = form_data - model_id = data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - model = app.state.MODELS[model_id] - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - if action_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[action_id] - else: - function_module, _, _ = load_function_module_by_id(action_id) - webui_app.state.FUNCTIONS[action_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(action_id) - function_module.valves = function_module.Valves(**(valves if valves else {})) - - if hasattr(function_module, "action"): - try: - action = function_module.action - - # Get the signature of the function - sig = inspect.signature(action) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": sub_action_id if sub_action_id is not None else action_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - action_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(action): - data = await action(**params) - else: - data = action(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data - - -################################## -# -# Task Endpoints -# -################################## - - -# TODO: Refactor task API endpoints below into a separate file - - -@app.get("/api/task/config") -async def get_task_config(user=Depends(get_verified_user)): - return { - "TASK_MODEL": app.state.config.TASK_MODEL, - "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, - "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, - "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, - "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY, - "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, - "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - } - - -class TaskConfigForm(BaseModel): - TASK_MODEL: Optional[str] - TASK_MODEL_EXTERNAL: Optional[str] - TITLE_GENERATION_PROMPT_TEMPLATE: str - TAGS_GENERATION_PROMPT_TEMPLATE: str - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str - 
ENABLE_SEARCH_QUERY: bool - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str - - -@app.post("/api/task/config/update") -async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)): - app.state.config.TASK_MODEL = form_data.TASK_MODEL - app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL - app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( - form_data.TITLE_GENERATION_PROMPT_TEMPLATE - ) - app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = ( - form_data.TAGS_GENERATION_PROMPT_TEMPLATE - ) - - app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( - form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE - ) - app.state.config.ENABLE_SEARCH_QUERY = form_data.ENABLE_SEARCH_QUERY - app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( - form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - ) - - return { - "TASK_MODEL": app.state.config.TASK_MODEL, - "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, - "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, - "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, - "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, - "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY, - "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - } - - -@app.post("/api/task/title/completions") -async def generate_title(form_data: dict, user=Depends(get_verified_user)): - print("generate_title") - - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) - print(task_model_id) - - model = app.state.MODELS[task_model_id] - - if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": - template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE - else: - template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. 
- -Examples of titles: -📉 Stock Market Trends -🍪 Perfect Chocolate Chip Recipe -Evolution of Music Streaming -Remote Work Productivity Tips -Artificial Intelligence in Healthcare -🎮 Video Game Development Insights - - -{{MESSAGES:END:2}} -""" - - content = title_generation_template( - template, - form_data["messages"], - { - "name": user.name, - "location": user.info.get("location") if user.info else None, - }, - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - **( - {"max_tokens": 50} - if app.state.MODELS[task_model_id]["owned_by"] == "ollama" - else { - "max_completion_tokens": 50, - } - ), - "chat_id": form_data.get("chat_id", None), - "metadata": {"task": str(TASKS.TITLE_GENERATION), "task_body": form_data}, - } - log.debug(payload) - - # Handle pipeline filters +async def chat_action( + request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user) +): try: - payload = filter_pipeline(payload, user) + return await chat_action_handler(request, action_id, form_data, user) except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/tags/completions") -async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): - print("generate_chat_tags") - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) - print(task_model_id) - - if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": - template = app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE - else: - template = """### Task: -Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags. - -### Guidelines: -- Start with high-level domains (e.g. 
Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education) -- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation -- If content is too short (less than 3 messages) or too diverse, use only ["General"] -- Use the chat's primary language; default to English if multilingual -- Prioritize accuracy over specificity - -### Output: -JSON format: { "tags": ["tag1", "tag2", "tag3"] } - -### Chat History: - -{{MESSAGES:END:6}} -""" - - content = tags_generation_template( - template, form_data["messages"], {"name": user.name} - ) - - print("content", content) - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - "metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data}, - } - log.debug(payload) - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/query/completions") -async def generate_search_query(form_data: dict, user=Depends(get_verified_user)): - print("generate_search_query") - if not app.state.config.ENABLE_SEARCH_QUERY: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Search query generation is disabled", - ) - - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) - print(task_model_id) - - model = app.state.MODELS[task_model_id] - - if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "": - template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE - else: - template = """Given the user's message and interaction history, decide if a web search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt. Today's date is {{CURRENT_DATE}}. 
- -User Message: -{{prompt:end:4000}} - -Interaction History: -{{MESSAGES:END:6}} - -Search Query:""" - - content = search_query_generation_template( - template, form_data["messages"], {"name": user.name} - ) - - print("content", content) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - **( - {"max_tokens": 30} - if app.state.MODELS[task_model_id]["owned_by"] == "ollama" - else { - "max_completion_tokens": 30, - } - ), - "metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data}, - } - log.debug(payload) - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/emoji/completions") -async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): - print("generate_emoji") - - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) - print(task_model_id) - - model = app.state.MODELS[task_model_id] - - template = ''' -Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). 
- -Message: """{{prompt}}""" -''' - content = emoji_generation_template( - template, - form_data["prompt"], - { - "name": user.name, - "location": user.info.get("location") if user.info else None, - }, - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": False, - **( - {"max_tokens": 4} - if app.state.MODELS[task_model_id]["owned_by"] == "ollama" - else { - "max_completion_tokens": 4, - } - ), - "chat_id": form_data.get("chat_id", None), - "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data}, - } - log.debug(payload) - - # Handle pipeline filters - try: - payload = filter_pipeline(payload, user) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -@app.post("/api/task/moa/completions") -async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)): - print("generate_moa_response") - - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) - print(task_model_id) - - model = app.state.MODELS[task_model_id] - - template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" - -Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. 
- -Responses from models: {{responses}}""" - - content = moa_response_generation_template( - template, - form_data["prompt"], - form_data["responses"], - ) - - payload = { - "model": task_model_id, - "messages": [{"role": "user", "content": content}], - "stream": form_data.get("stream", False), - "chat_id": form_data.get("chat_id", None), - "metadata": { - "task": str(TASKS.MOA_RESPONSE_GENERATION), - "task_body": form_data, - }, - } - log.debug(payload) - - try: - payload = filter_pipeline(payload, user) - except Exception as e: - if len(e.args) > 1: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - else: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - if "chat_id" in payload: - del payload["chat_id"] - - return await generate_chat_completions(form_data=payload, user=user) - - -################################## -# -# Pipelines Endpoints -# -################################## - - -# TODO: Refactor pipelines API endpoints below into a separate file - - -@app.get("/api/pipelines/list") -async def get_pipelines_list(user=Depends(get_admin_user)): - responses = await get_openai_models(raw=True) - - print(responses) - urlIdxs = [ - idx - for idx, response in enumerate(responses) - if response is not None and "pipelines" in response - ] - - return { - "data": [ - { - "url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx], - "idx": urlIdx, - } - for urlIdx in urlIdxs - ] - } - - -@app.post("/api/pipelines/upload") -async def upload_pipeline( - urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) -): - print("upload_pipeline", urlIdx, file.filename) - # Check if the uploaded file is a python file - if not (file.filename and file.filename.endswith(".py")): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Only Python (.py) files are allowed.", - ) - - upload_folder = f"{CACHE_DIR}/pipelines" - os.makedirs(upload_folder, exist_ok=True) - file_path = os.path.join(upload_folder, file.filename) - - r = None - try: - # Save the uploaded file - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - - with open(file_path, "rb") as f: - files = {"file": f} - r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - status_code = status.HTTP_404_NOT_FOUND - if r is not None: - status_code = r.status_code - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=status_code, - detail=detail, - ) - finally: - # Ensure the file is deleted after the upload is completed or on failure - if os.path.exists(file_path): - os.remove(file_path) - - -class AddPipelineForm(BaseModel): - url: str - urlIdx: int - - -@app.post("/api/pipelines/add") -async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)): - r = None - try: - urlIdx = form_data.urlIdx - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/pipelines/add", headers=headers, 
json={"url": form_data.url} - ) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -class DeletePipelineForm(BaseModel): - id: str - urlIdx: int - - -@app.delete("/api/pipelines/delete") -async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)): - r = None - try: - urlIdx = form_data.urlIdx - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.delete( - f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id} - ) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.get("/api/pipelines") -async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/pipelines", headers=headers) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.get("/api/pipelines/{pipeline_id}/valves") -async def get_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), -): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.get("/api/pipelines/{pipeline_id}/valves/spec") -async def get_pipeline_valves_spec( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), -): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers) - - r.raise_for_status() - data = r.json() - - return {**data} - except 
Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, - ) - - -@app.post("/api/pipelines/{pipeline_id}/valves/update") -async def update_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - form_data: dict, - user=Depends(get_admin_user), -): - r = None - try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{pipeline_id}/valves/update", - headers=headers, - json={**form_data}, - ) - - r.raise_for_status() - data = r.json() - - return {**data} - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - detail = "Pipeline not found" - - if r is not None: - try: - res = r.json() - if "detail" in res: - detail = res["detail"] - except Exception: - pass - - raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), - detail=detail, + detail=str(e), ) @@ -2215,11 +892,24 @@ async def get_app_config(request: Request): user = None if "token" in request.cookies: token = request.cookies.get("token") - data = decode_token(token) + try: + data = decode_token(token) + except Exception as e: + log.debug(e) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token", + ) if data is not None and "id" in data: user = Users.get_user_by_id(data["id"]) + onboarding = False + if user is None: + user_count = Users.get_num_users() + onboarding = user_count == 0 + return { + **({"onboarding": True} if onboarding else {}), "status": True, "name": WEBUI_NAME, "version": VERSION, @@ -2232,15 +922,17 @@ async def get_app_config(request: Request): }, "features": { "auth": WEBUI_AUTH, - "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), - "enable_signup": webui_app.state.config.ENABLE_SIGNUP, - "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM, + "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER), + "enable_ldap": app.state.config.ENABLE_LDAP, + "enable_api_key": app.state.config.ENABLE_API_KEY, + "enable_signup": app.state.config.ENABLE_SIGNUP, + "enable_login_form": app.state.config.ENABLE_LOGIN_FORM, **( { - "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH, - "enable_image_generation": images_app.state.config.ENABLED, - "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, - "enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING, + "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH, + "enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION, + "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING, + "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING, "enable_admin_export": ENABLE_ADMIN_EXPORT, "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, } @@ -2254,23 +946,23 @@ async def get_app_config(request: Request): }, **( { - "default_models": webui_app.state.config.DEFAULT_MODELS, - "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, + "default_models": app.state.config.DEFAULT_MODELS, + "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "audio": { 
"tts": { - "engine": audio_app.state.config.TTS_ENGINE, - "voice": audio_app.state.config.TTS_VOICE, - "split_on": audio_app.state.config.TTS_SPLIT_ON, + "engine": app.state.config.TTS_ENGINE, + "voice": app.state.config.TTS_VOICE, + "split_on": app.state.config.TTS_SPLIT_ON, }, "stt": { - "engine": audio_app.state.config.STT_ENGINE, + "engine": app.state.config.STT_ENGINE, }, }, "file": { - "max_size": retrieval_app.state.config.FILE_MAX_SIZE, - "max_count": retrieval_app.state.config.FILE_MAX_COUNT, + "max_size": app.state.config.FILE_MAX_SIZE, + "max_count": app.state.config.FILE_MAX_COUNT, }, - "permissions": {**webui_app.state.config.USER_PERMISSIONS}, + "permissions": {**app.state.config.USER_PERMISSIONS}, } if user is not None else {} @@ -2278,33 +970,8 @@ async def get_app_config(request: Request): } -@app.get("/api/config/model/filter") -async def get_model_filter_config(user=Depends(get_admin_user)): - return { - "enabled": app.state.config.ENABLE_MODEL_FILTER, - "models": app.state.config.MODEL_FILTER_LIST, - } - - -class ModelFilterConfigForm(BaseModel): - enabled: bool - models: list[str] - - -@app.post("/api/config/model/filter") -async def update_model_filter_config( - form_data: ModelFilterConfigForm, user=Depends(get_admin_user) -): - app.state.config.ENABLE_MODEL_FILTER = form_data.enabled - app.state.config.MODEL_FILTER_LIST = form_data.models - - return { - "enabled": app.state.config.ENABLE_MODEL_FILTER, - "models": app.state.config.MODEL_FILTER_LIST, - } - - -# TODO: webhook endpoint should be under config endpoints +class UrlForm(BaseModel): + url: str @app.get("/api/webhook") @@ -2314,14 +981,10 @@ async def get_webhook_url(user=Depends(get_admin_user)): } -class UrlForm(BaseModel): - url: str - - @app.post("/api/webhook") async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): app.state.config.WEBHOOK_URL = form_data.url - webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL + app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL return {"url": app.state.config.WEBHOOK_URL} @@ -2332,11 +995,6 @@ async def get_app_version(): } -@app.get("/api/changelog") -async def get_app_changelog(): - return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5} - - @app.get("/api/version/updates") async def get_app_latest_release_version(): if OFFLINE_MODE: @@ -2360,6 +1018,11 @@ async def get_app_latest_release_version(): return {"current": VERSION, "latest": VERSION} +@app.get("/api/changelog") +async def get_app_changelog(): + return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5} + + ############################ # OAuth Login & Callback ############################ @@ -2400,7 +1063,7 @@ async def get_manifest_json(): "start_url": "/", "display": "standalone", "background_color": "#343541", - "orientation": "any", + "orientation": "natural", "icons": [ { "src": "/static/logo.png", @@ -2447,7 +1110,6 @@ async def healthcheck_with_db(): app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") - if os.path.exists(FRONTEND_BUILD_DIR): mimetypes.add_type("text/javascript", ".js") app.mount( diff --git a/backend/open_webui/migrations/env.py b/backend/open_webui/migrations/env.py index 5e860c8a0..128881647 100644 --- a/backend/open_webui/migrations/env.py +++ b/backend/open_webui/migrations/env.py @@ -1,7 +1,7 @@ from logging.config import fileConfig from alembic import context -from open_webui.apps.webui.models.auths import Auth +from 
open_webui.models.auths import Auth from open_webui.env import DATABASE_URL from sqlalchemy import engine_from_config, pool diff --git a/backend/open_webui/migrations/script.py.mako b/backend/open_webui/migrations/script.py.mako index 01e730e77..bcf5567fd 100644 --- a/backend/open_webui/migrations/script.py.mako +++ b/backend/open_webui/migrations/script.py.mako @@ -9,7 +9,7 @@ from typing import Sequence, Union from alembic import op import sqlalchemy as sa -import open_webui.apps.webui.internal.db +import open_webui.internal.db ${imports if imports else ""} # revision identifiers, used by Alembic. diff --git a/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py b/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py index 607a7b2c9..9e56282ef 100644 --- a/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py +++ b/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py @@ -11,8 +11,8 @@ from typing import Sequence, Union import sqlalchemy as sa from alembic import op -import open_webui.apps.webui.internal.db -from open_webui.apps.webui.internal.db import JSONField +import open_webui.internal.db +from open_webui.internal.db import JSONField from open_webui.migrations.util import get_existing_tables # revision identifiers, used by Alembic. diff --git a/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py b/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py new file mode 100644 index 000000000..a75211584 --- /dev/null +++ b/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py @@ -0,0 +1,85 @@ +"""Add group table + +Revision ID: 922e7a387820 +Revises: 4ace53fd72c8 +Create Date: 2024-11-14 03:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa + +revision = "922e7a387820" +down_revision = "4ace53fd72c8" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "group", + sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), + sa.Column("user_id", sa.Text(), nullable=True), + sa.Column("name", sa.Text(), nullable=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("data", sa.JSON(), nullable=True), + sa.Column("meta", sa.JSON(), nullable=True), + sa.Column("permissions", sa.JSON(), nullable=True), + sa.Column("user_ids", sa.JSON(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + ) + + # Add 'access_control' column to 'model' table + op.add_column( + "model", + sa.Column("access_control", sa.JSON(), nullable=True), + ) + + # Add 'is_active' column to 'model' table + op.add_column( + "model", + sa.Column( + "is_active", + sa.Boolean(), + nullable=False, + server_default=sa.sql.expression.true(), + ), + ) + + # Add 'access_control' column to 'knowledge' table + op.add_column( + "knowledge", + sa.Column("access_control", sa.JSON(), nullable=True), + ) + + # Add 'access_control' column to 'prompt' table + op.add_column( + "prompt", + sa.Column("access_control", sa.JSON(), nullable=True), + ) + + # Add 'access_control' column to 'tools' table + op.add_column( + "tool", + sa.Column("access_control", sa.JSON(), nullable=True), + ) + + +def downgrade(): + op.drop_table("group") + + # Drop 'access_control' column from 'model' table + op.drop_column("model", "access_control") + + # Drop 'is_active' column from 'model' table + op.drop_column("model", "is_active") + + # Drop 'access_control' column from 'knowledge' table + op.drop_column("knowledge", 
"access_control") + + # Drop 'access_control' column from 'prompt' table + op.drop_column("prompt", "access_control") + + # Drop 'access_control' column from 'tools' table + op.drop_column("tool", "access_control") diff --git a/backend/open_webui/apps/webui/models/auths.py b/backend/open_webui/models/auths.py similarity index 95% rename from backend/open_webui/apps/webui/models/auths.py rename to backend/open_webui/models/auths.py index 167b9f6dc..f07c36c73 100644 --- a/backend/open_webui/apps/webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -2,12 +2,12 @@ import logging import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.users import UserModel, Users +from open_webui.internal.db import Base, get_db +from open_webui.models.users import UserModel, Users from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel from sqlalchemy import Boolean, Column, String, Text -from open_webui.utils.utils import verify_password +from open_webui.utils.auth import verify_password log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -64,6 +64,11 @@ class SigninForm(BaseModel): password: str +class LdapForm(BaseModel): + user: str + password: str + + class ProfileImageUrlForm(BaseModel): profile_image_url: str diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/models/chats.py similarity index 97% rename from backend/open_webui/apps/webui/models/chats.py rename to backend/open_webui/models/chats.py index f6a1e4548..3e621a150 100644 --- a/backend/open_webui/apps/webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -3,8 +3,8 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.tags import TagModel, Tag, Tags +from open_webui.internal.db import Base, get_db +from open_webui.models.tags import TagModel, Tag, Tags from pydantic import BaseModel, ConfigDict @@ -203,15 +203,22 @@ class ChatTable: def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: try: with get_db() as db: - print("update_shared_chat_by_id") chat = db.get(Chat, chat_id) - print(chat) - chat.title = chat.title - chat.chat = chat.chat - db.commit() - db.refresh(chat) + shared_chat = ( + db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first() + ) - return self.get_chat_by_id(chat.share_id) + if shared_chat is None: + return self.insert_shared_chat_by_chat_id(chat_id) + + shared_chat.title = chat.title + shared_chat.chat = chat.chat + + shared_chat.updated_at = int(time.time()) + db.commit() + db.refresh(shared_chat) + + return ChatModel.model_validate(shared_chat) except Exception: return None diff --git a/backend/open_webui/apps/webui/models/feedbacks.py b/backend/open_webui/models/feedbacks.py similarity index 98% rename from backend/open_webui/apps/webui/models/feedbacks.py rename to backend/open_webui/models/feedbacks.py index c2356dfd8..7ff5c4540 100644 --- a/backend/open_webui/apps/webui/models/feedbacks.py +++ b/backend/open_webui/models/feedbacks.py @@ -3,8 +3,8 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.chats import Chats +from open_webui.internal.db import Base, get_db +from open_webui.models.chats import Chats from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict diff --git 
a/backend/open_webui/apps/webui/models/files.py b/backend/open_webui/models/files.py similarity index 98% rename from backend/open_webui/apps/webui/models/files.py rename to backend/open_webui/models/files.py index 31c9164b6..4050b0140 100644 --- a/backend/open_webui/apps/webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -2,7 +2,7 @@ import logging import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db +from open_webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON diff --git a/backend/open_webui/apps/webui/models/folders.py b/backend/open_webui/models/folders.py similarity index 98% rename from backend/open_webui/apps/webui/models/folders.py rename to backend/open_webui/models/folders.py index 90e8880aa..040774196 100644 --- a/backend/open_webui/apps/webui/models/folders.py +++ b/backend/open_webui/models/folders.py @@ -3,8 +3,8 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.chats import Chats +from open_webui.internal.db import Base, get_db +from open_webui.models.chats import Chats from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict diff --git a/backend/open_webui/apps/webui/models/functions.py b/backend/open_webui/models/functions.py similarity index 98% rename from backend/open_webui/apps/webui/models/functions.py rename to backend/open_webui/models/functions.py index fda155075..6c6aed862 100644 --- a/backend/open_webui/apps/webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -2,8 +2,8 @@ import logging import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db -from open_webui.apps.webui.models.users import Users +from open_webui.internal.db import Base, JSONField, get_db +from open_webui.models.users import Users from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Boolean, Column, String, Text diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py new file mode 100644 index 000000000..8f0728411 --- /dev/null +++ b/backend/open_webui/models/groups.py @@ -0,0 +1,186 @@ +import json +import logging +import time +from typing import Optional +import uuid + +from open_webui.internal.db import Base, get_db +from open_webui.env import SRC_LOG_LEVELS + +from open_webui.models.files import FileMetadataResponse + + +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, String, Text, JSON, func + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +#################### +# UserGroup DB Schema +#################### + + +class Group(Base): + __tablename__ = "group" + + id = Column(Text, unique=True, primary_key=True) + user_id = Column(Text) + + name = Column(Text) + description = Column(Text) + + data = Column(JSON, nullable=True) + meta = Column(JSON, nullable=True) + + permissions = Column(JSON, nullable=True) + user_ids = Column(JSON, nullable=True) + + created_at = Column(BigInteger) + updated_at = Column(BigInteger) + + +class GroupModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str + user_id: str + + name: str + description: str + + data: Optional[dict] = None + meta: Optional[dict] = None + 
+ permissions: Optional[dict] = None + user_ids: list[str] = [] + + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class GroupResponse(BaseModel): + id: str + user_id: str + name: str + description: str + permissions: Optional[dict] = None + data: Optional[dict] = None + meta: Optional[dict] = None + user_ids: list[str] = [] + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + +class GroupForm(BaseModel): + name: str + description: str + + +class GroupUpdateForm(GroupForm): + permissions: Optional[dict] = None + user_ids: Optional[list[str]] = None + admin_ids: Optional[list[str]] = None + + +class GroupTable: + def insert_new_group( + self, user_id: str, form_data: GroupForm + ) -> Optional[GroupModel]: + with get_db() as db: + group = GroupModel( + **{ + **form_data.model_dump(), + "id": str(uuid.uuid4()), + "user_id": user_id, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + + try: + result = Group(**group.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return GroupModel.model_validate(result) + else: + return None + + except Exception: + return None + + def get_groups(self) -> list[GroupModel]: + with get_db() as db: + return [ + GroupModel.model_validate(group) + for group in db.query(Group).order_by(Group.updated_at.desc()).all() + ] + + def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: + with get_db() as db: + return [ + GroupModel.model_validate(group) + for group in db.query(Group) + .filter( + func.json_array_length(Group.user_ids) > 0 + ) # Ensure array exists + .filter( + Group.user_ids.cast(String).like(f'%"{user_id}"%') + ) # String-based check + .order_by(Group.updated_at.desc()) + .all() + ] + + def get_group_by_id(self, id: str) -> Optional[GroupModel]: + try: + with get_db() as db: + group = db.query(Group).filter_by(id=id).first() + return GroupModel.model_validate(group) if group else None + except Exception: + return None + + def update_group_by_id( + self, id: str, form_data: GroupUpdateForm, overwrite: bool = False + ) -> Optional[GroupModel]: + try: + with get_db() as db: + db.query(Group).filter_by(id=id).update( + { + **form_data.model_dump(exclude_none=True), + "updated_at": int(time.time()), + } + ) + db.commit() + return self.get_group_by_id(id=id) + except Exception as e: + log.exception(e) + return None + + def delete_group_by_id(self, id: str) -> bool: + try: + with get_db() as db: + db.query(Group).filter_by(id=id).delete() + db.commit() + return True + except Exception: + return False + + def delete_all_groups(self) -> bool: + with get_db() as db: + try: + db.query(Group).delete() + db.commit() + + return True + except Exception: + return False + + +Groups = GroupTable() diff --git a/backend/open_webui/apps/webui/models/knowledge.py b/backend/open_webui/models/knowledge.py similarity index 53% rename from backend/open_webui/apps/webui/models/knowledge.py rename to backend/open_webui/models/knowledge.py index 269ad8cc3..bed3d5542 100644 --- a/backend/open_webui/apps/webui/models/knowledge.py +++ b/backend/open_webui/models/knowledge.py @@ -4,15 +4,17 @@ import time from typing import Optional import uuid -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from open_webui.env import SRC_LOG_LEVELS -from open_webui.apps.webui.models.files import FileMetadataResponse +from open_webui.models.files import 
FileMetadataResponse +from open_webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON +from open_webui.utils.access_control import has_access log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -34,6 +36,23 @@ class Knowledge(Base): data = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) + access_control = Column(JSON, nullable=True) # Controls data access levels. + # Defines access control rules for this entry. + # - `None`: Public access, available to all users with the "user" role. + # - `{}`: Private access, restricted exclusively to the owner. + # - Custom permissions: Specific access control for reading and writing; + # Can specify group or user-level restrictions: + # { + # "read": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # }, + # "write": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # } + # } + created_at = Column(BigInteger) updated_at = Column(BigInteger) @@ -50,6 +69,8 @@ class KnowledgeModel(BaseModel): data: Optional[dict] = None meta: Optional[dict] = None + access_control: Optional[dict] = None + created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -59,15 +80,15 @@ class KnowledgeModel(BaseModel): #################### -class KnowledgeResponse(BaseModel): - id: str - name: str - description: str - data: Optional[dict] = None - meta: Optional[dict] = None - created_at: int # timestamp in epoch - updated_at: int # timestamp in epoch +class KnowledgeUserModel(KnowledgeModel): + user: Optional[UserResponse] = None + +class KnowledgeResponse(KnowledgeModel): + files: Optional[list[FileMetadataResponse | dict]] = None + + +class KnowledgeUserResponse(KnowledgeUserModel): files: Optional[list[FileMetadataResponse | dict]] = None @@ -75,12 +96,7 @@ class KnowledgeForm(BaseModel): name: str description: str data: Optional[dict] = None - - -class KnowledgeUpdateForm(BaseModel): - name: Optional[str] = None - description: Optional[str] = None - data: Optional[dict] = None + access_control: Optional[dict] = None class KnowledgeTable: @@ -110,14 +126,33 @@ class KnowledgeTable: except Exception: return None - def get_knowledge_items(self) -> list[KnowledgeModel]: + def get_knowledge_bases(self) -> list[KnowledgeUserModel]: with get_db() as db: - return [ - KnowledgeModel.model_validate(knowledge) - for knowledge in db.query(Knowledge) - .order_by(Knowledge.updated_at.desc()) - .all() - ] + knowledge_bases = [] + for knowledge in ( + db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all() + ): + user = Users.get_user_by_id(knowledge.user_id) + knowledge_bases.append( + KnowledgeUserModel.model_validate( + { + **KnowledgeModel.model_validate(knowledge).model_dump(), + "user": user.model_dump() if user else None, + } + ) + ) + return knowledge_bases + + def get_knowledge_bases_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[KnowledgeUserModel]: + knowledge_bases = self.get_knowledge_bases() + return [ + knowledge_base + for knowledge_base in knowledge_bases + if knowledge_base.user_id == user_id + or has_access(user_id, permission, knowledge_base.access_control) + ] def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: try: @@ -128,14 +163,32 @@ class KnowledgeTable: return None def update_knowledge_by_id( - self, id: str, form_data: KnowledgeUpdateForm, overwrite: bool = False + self, id: str, 
form_data: KnowledgeForm, overwrite: bool = False ) -> Optional[KnowledgeModel]: try: with get_db() as db: knowledge = self.get_knowledge_by_id(id=id) db.query(Knowledge).filter_by(id=id).update( { - **form_data.model_dump(exclude_none=True), + **form_data.model_dump(), + "updated_at": int(time.time()), + } + ) + db.commit() + return self.get_knowledge_by_id(id=id) + except Exception as e: + log.exception(e) + return None + + def update_knowledge_data_by_id( + self, id: str, data: dict + ) -> Optional[KnowledgeModel]: + try: + with get_db() as db: + knowledge = self.get_knowledge_by_id(id=id) + db.query(Knowledge).filter_by(id=id).update( + { + "data": data, "updated_at": int(time.time()), } ) diff --git a/backend/open_webui/apps/webui/models/memories.py b/backend/open_webui/models/memories.py similarity index 98% rename from backend/open_webui/apps/webui/models/memories.py rename to backend/open_webui/models/memories.py index 6686058d3..c8dae9726 100644 --- a/backend/open_webui/apps/webui/models/memories.py +++ b/backend/open_webui/models/memories.py @@ -2,7 +2,7 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text diff --git a/backend/open_webui/apps/webui/models/models.py b/backend/open_webui/models/models.py similarity index 54% rename from backend/open_webui/apps/webui/models/models.py rename to backend/open_webui/models/models.py index 9bdffb9bc..f2f59d7c4 100644 --- a/backend/open_webui/apps/webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -2,10 +2,21 @@ import logging import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db +from open_webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS + +from open_webui.models.users import Users, UserResponse + + from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, Text + +from sqlalchemy import or_, and_, func +from sqlalchemy.dialects import postgresql, sqlite +from sqlalchemy import BigInteger, Column, Text, JSON, Boolean + + +from open_webui.utils.access_control import has_access + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -67,6 +78,25 @@ class Model(Base): Holds a JSON encoded blob of metadata, see `ModelMeta`. """ + access_control = Column(JSON, nullable=True) # Controls data access levels. + # Defines access control rules for this entry. + # - `None`: Public access, available to all users with the "user" role. + # - `{}`: Private access, restricted exclusively to the owner. 
+ # - Custom permissions: Specific access control for reading and writing; + # Can specify group or user-level restrictions: + # { + # "read": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # }, + # "write": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # } + # } + + is_active = Column(Boolean, default=True) + updated_at = Column(BigInteger) created_at = Column(BigInteger) @@ -80,6 +110,9 @@ class ModelModel(BaseModel): params: ModelParams meta: ModelMeta + access_control: Optional[dict] = None + + is_active: bool updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -91,12 +124,12 @@ class ModelModel(BaseModel): #################### -class ModelResponse(BaseModel): - id: str - name: str - meta: ModelMeta - updated_at: int # timestamp in epoch - created_at: int # timestamp in epoch +class ModelUserResponse(ModelModel): + user: Optional[UserResponse] = None + + +class ModelResponse(ModelModel): + pass class ModelForm(BaseModel): @@ -105,6 +138,8 @@ class ModelForm(BaseModel): name: str meta: ModelMeta params: ModelParams + access_control: Optional[dict] = None + is_active: bool = True class ModelsTable: @@ -138,6 +173,39 @@ class ModelsTable: with get_db() as db: return [ModelModel.model_validate(model) for model in db.query(Model).all()] + def get_models(self) -> list[ModelUserResponse]: + with get_db() as db: + models = [] + for model in db.query(Model).filter(Model.base_model_id != None).all(): + user = Users.get_user_by_id(model.user_id) + models.append( + ModelUserResponse.model_validate( + { + **ModelModel.model_validate(model).model_dump(), + "user": user.model_dump() if user else None, + } + ) + ) + return models + + def get_base_models(self) -> list[ModelModel]: + with get_db() as db: + return [ + ModelModel.model_validate(model) + for model in db.query(Model).filter(Model.base_model_id == None).all() + ] + + def get_models_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[ModelUserResponse]: + models = self.get_models() + return [ + model + for model in models + if model.user_id == user_id + or has_access(user_id, permission, model.access_control) + ] + def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: with get_db() as db: @@ -146,6 +214,23 @@ class ModelsTable: except Exception: return None + def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: + with get_db() as db: + try: + is_active = db.query(Model).filter_by(id=id).first().is_active + + db.query(Model).filter_by(id=id).update( + { + "is_active": not is_active, + "updated_at": int(time.time()), + } + ) + db.commit() + + return self.get_model_by_id(id) + except Exception: + return None + def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: try: with get_db() as db: @@ -153,7 +238,7 @@ class ModelsTable: result = ( db.query(Model) .filter_by(id=id) - .update(model.model_dump(exclude={"id"}, exclude_none=True)) + .update(model.model_dump(exclude={"id"})) ) db.commit() @@ -175,5 +260,15 @@ class ModelsTable: except Exception: return False + def delete_all_models(self) -> bool: + try: + with get_db() as db: + db.query(Model).delete() + db.commit() + + return True + except Exception: + return False + Models = ModelsTable() diff --git a/backend/open_webui/apps/webui/models/prompts.py b/backend/open_webui/models/prompts.py similarity index 53% rename from backend/open_webui/apps/webui/models/prompts.py rename to backend/open_webui/models/prompts.py 
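Not part of the patch: a minimal sketch of how the `access_control` column documented in the comments above is shaped and consumed, assuming the group and user IDs below are placeholders and that `has_access` (whose implementation is not shown in this excerpt, presumably resolving group membership via the new `group` table) behaves as its call sites here suggest.

# Illustrative only; IDs are placeholders, not values from this patch.
from open_webui.models.models import Models
from open_webui.utils.access_control import has_access

# Per the schema comments above:
#   None -> public to all users with the "user" role
#   {}   -> private, owner only
access_control = {
    "read":  {"group_ids": ["group_id1"], "user_ids": ["user_id1"]},
    "write": {"group_ids": [],            "user_ids": ["user_id1"]},
}

# The new per-user helpers filter with the same rule used throughout this patch:
# owner OR has_access(user_id, permission, entry.access_control).
readable = Models.get_models_by_user_id("user_id1", permission="read")
for model in readable:
    assert model.user_id == "user_id1" or has_access(
        "user_id1", "read", model.access_control
    )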
index 6b98e5c53..8ef4cd2be 100644 --- a/backend/open_webui/apps/webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -1,9 +1,13 @@ import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db +from open_webui.models.users import Users, UserResponse + from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy import BigInteger, Column, String, Text, JSON + +from open_webui.utils.access_control import has_access #################### # Prompts DB Schema @@ -19,6 +23,23 @@ class Prompt(Base): content = Column(Text) timestamp = Column(BigInteger) + access_control = Column(JSON, nullable=True) # Controls data access levels. + # Defines access control rules for this entry. + # - `None`: Public access, available to all users with the "user" role. + # - `{}`: Private access, restricted exclusively to the owner. + # - Custom permissions: Specific access control for reading and writing; + # Can specify group or user-level restrictions: + # { + # "read": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # }, + # "write": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # } + # } + class PromptModel(BaseModel): command: str @@ -27,6 +48,7 @@ class PromptModel(BaseModel): content: str timestamp: int # timestamp in epoch + access_control: Optional[dict] = None model_config = ConfigDict(from_attributes=True) @@ -35,10 +57,15 @@ class PromptModel(BaseModel): #################### +class PromptUserResponse(PromptModel): + user: Optional[UserResponse] = None + + class PromptForm(BaseModel): command: str title: str content: str + access_control: Optional[dict] = None class PromptsTable: @@ -48,16 +75,14 @@ class PromptsTable: prompt = PromptModel( **{ "user_id": user_id, - "command": form_data.command, - "title": form_data.title, - "content": form_data.content, + **form_data.model_dump(), "timestamp": int(time.time()), } ) try: with get_db() as db: - result = Prompt(**prompt.dict()) + result = Prompt(**prompt.model_dump()) db.add(result) db.commit() db.refresh(result) @@ -76,11 +101,34 @@ class PromptsTable: except Exception: return None - def get_prompts(self) -> list[PromptModel]: + def get_prompts(self) -> list[PromptUserResponse]: with get_db() as db: - return [ - PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() - ] + prompts = [] + + for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all(): + user = Users.get_user_by_id(prompt.user_id) + prompts.append( + PromptUserResponse.model_validate( + { + **PromptModel.model_validate(prompt).model_dump(), + "user": user.model_dump() if user else None, + } + ) + ) + + return prompts + + def get_prompts_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[PromptUserResponse]: + prompts = self.get_prompts() + + return [ + prompt + for prompt in prompts + if prompt.user_id == user_id + or has_access(user_id, permission, prompt.access_control) + ] def update_prompt_by_command( self, command: str, form_data: PromptForm @@ -90,6 +138,7 @@ class PromptsTable: prompt = db.query(Prompt).filter_by(command=command).first() prompt.title = form_data.title prompt.content = form_data.content + prompt.access_control = form_data.access_control prompt.timestamp = int(time.time()) db.commit() return PromptModel.model_validate(prompt) diff --git a/backend/open_webui/apps/webui/models/tags.py 
b/backend/open_webui/models/tags.py similarity index 98% rename from backend/open_webui/apps/webui/models/tags.py rename to backend/open_webui/models/tags.py index 7424a2660..3e812db95 100644 --- a/backend/open_webui/apps/webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -3,7 +3,7 @@ import time import uuid from typing import Optional -from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.internal.db import Base, get_db from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/apps/webui/models/tools.py b/backend/open_webui/models/tools.py similarity index 72% rename from backend/open_webui/apps/webui/models/tools.py rename to backend/open_webui/models/tools.py index e06f83452..a5f13ebb7 100644 --- a/backend/open_webui/apps/webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -2,11 +2,14 @@ import logging import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db -from open_webui.apps.webui.models.users import Users +from open_webui.internal.db import Base, JSONField, get_db +from open_webui.models.users import Users, UserResponse from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy import BigInteger, Column, String, Text, JSON + +from open_webui.utils.access_control import has_access + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -26,6 +29,24 @@ class Tool(Base): specs = Column(JSONField) meta = Column(JSONField) valves = Column(JSONField) + + access_control = Column(JSON, nullable=True) # Controls data access levels. + # Defines access control rules for this entry. + # - `None`: Public access, available to all users with the "user" role. + # - `{}`: Private access, restricted exclusively to the owner. 
+ # - Custom permissions: Specific access control for reading and writing; + # Can specify group or user-level restrictions: + # { + # "read": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # }, + # "write": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # } + # } + updated_at = Column(BigInteger) created_at = Column(BigInteger) @@ -42,6 +63,8 @@ class ToolModel(BaseModel): content: str specs: list[dict] meta: ToolMeta + access_control: Optional[dict] = None + updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -53,20 +76,30 @@ class ToolModel(BaseModel): #################### +class ToolUserModel(ToolModel): + user: Optional[UserResponse] = None + + class ToolResponse(BaseModel): id: str user_id: str name: str meta: ToolMeta + access_control: Optional[dict] = None updated_at: int # timestamp in epoch created_at: int # timestamp in epoch +class ToolUserResponse(ToolResponse): + user: Optional[UserResponse] = None + + class ToolForm(BaseModel): id: str name: str content: str meta: ToolMeta + access_control: Optional[dict] = None class ToolValves(BaseModel): @@ -109,9 +142,32 @@ class ToolsTable: except Exception: return None - def get_tools(self) -> list[ToolModel]: + def get_tools(self) -> list[ToolUserModel]: with get_db() as db: - return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] + tools = [] + for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all(): + user = Users.get_user_by_id(tool.user_id) + tools.append( + ToolUserModel.model_validate( + { + **ToolModel.model_validate(tool).model_dump(), + "user": user.model_dump() if user else None, + } + ) + ) + return tools + + def get_tools_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[ToolUserModel]: + tools = self.get_tools() + + return [ + tool + for tool in tools + if tool.user_id == user_id + or has_access(user_id, permission, tool.access_control) + ] def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: diff --git a/backend/open_webui/apps/webui/models/users.py b/backend/open_webui/models/users.py similarity index 97% rename from backend/open_webui/apps/webui/models/users.py rename to backend/open_webui/models/users.py index 328618a67..5b6c27214 100644 --- a/backend/open_webui/apps/webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -1,8 +1,8 @@ import time from typing import Optional -from open_webui.apps.webui.internal.db import Base, JSONField, get_db -from open_webui.apps.webui.models.chats import Chats +from open_webui.internal.db import Base, JSONField, get_db +from open_webui.models.chats import Chats from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text @@ -62,6 +62,14 @@ class UserModel(BaseModel): #################### +class UserResponse(BaseModel): + id: str + name: str + email: str + role: str + profile_image_url: str + + class UserRoleUpdateForm(BaseModel): id: str role: str diff --git a/backend/open_webui/apps/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py similarity index 95% rename from backend/open_webui/apps/retrieval/loaders/main.py rename to backend/open_webui/retrieval/loaders/main.py index f0e8f804e..a9372f65a 100644 --- a/backend/open_webui/apps/retrieval/loaders/main.py +++ b/backend/open_webui/retrieval/loaders/main.py @@ -1,6 +1,7 @@ import requests import logging import ftfy +import sys from langchain_community.document_loaders import ( BSHTMLLoader, @@ 
-18,8 +19,9 @@ from langchain_community.document_loaders import ( YoutubeLoader, ) from langchain_core.documents import Document -from open_webui.env import SRC_LOG_LEVELS +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -106,7 +108,7 @@ class TikaLoader: if "Content-Type" in raw_metadata: headers["Content-Type"] = raw_metadata["Content-Type"] - log.info("Tika extracted text: %s", text) + log.debug("Tika extracted text: %s", text) return [Document(page_content=text, metadata=headers)] else: @@ -159,7 +161,7 @@ class Loader: elif file_ext in ["htm", "html"]: loader = BSHTMLLoader(file_path, open_encoding="unicode_escape") elif file_ext == "md": - loader = UnstructuredMarkdownLoader(file_path) + loader = TextLoader(file_path, autodetect_encoding=True) elif file_content_type == "application/epub+zip": loader = UnstructuredEPubLoader(file_path) elif ( diff --git a/backend/open_webui/retrieval/loaders/youtube.py b/backend/open_webui/retrieval/loaders/youtube.py new file mode 100644 index 000000000..8eb48488b --- /dev/null +++ b/backend/open_webui/retrieval/loaders/youtube.py @@ -0,0 +1,117 @@ +import logging + +from typing import Any, Dict, Generator, List, Optional, Sequence, Union +from urllib.parse import parse_qs, urlparse +from langchain_core.documents import Document +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + +ALLOWED_SCHEMES = {"http", "https"} +ALLOWED_NETLOCS = { + "youtu.be", + "m.youtube.com", + "youtube.com", + "www.youtube.com", + "www.youtube-nocookie.com", + "vid.plus", +} + + +def _parse_video_id(url: str) -> Optional[str]: + """Parse a YouTube URL and return the video ID if valid, otherwise None.""" + parsed_url = urlparse(url) + + if parsed_url.scheme not in ALLOWED_SCHEMES: + return None + + if parsed_url.netloc not in ALLOWED_NETLOCS: + return None + + path = parsed_url.path + + if path.endswith("/watch"): + query = parsed_url.query + parsed_query = parse_qs(query) + if "v" in parsed_query: + ids = parsed_query["v"] + video_id = ids if isinstance(ids, str) else ids[0] + else: + return None + else: + path = parsed_url.path.lstrip("/") + video_id = path.split("/")[-1] + + if len(video_id) != 11: # Video IDs are 11 characters long + return None + + return video_id + + +class YoutubeLoader: + """Load `YouTube` video transcripts.""" + + def __init__( + self, + video_id: str, + language: Union[str, Sequence[str]] = "en", + proxy_url: Optional[str] = None, + ): + """Initialize with YouTube video ID.""" + _video_id = _parse_video_id(video_id) + self.video_id = _video_id if _video_id is not None else video_id + self._metadata = {"source": video_id} + self.language = language + self.proxy_url = proxy_url + if isinstance(language, str): + self.language = [language] + else: + self.language = language + + def load(self) -> List[Document]: + """Load YouTube transcripts into `Document` objects.""" + try: + from youtube_transcript_api import ( + NoTranscriptFound, + TranscriptsDisabled, + YouTubeTranscriptApi, + ) + except ImportError: + raise ImportError( + 'Could not import "youtube_transcript_api" Python package. ' + "Please install it with `pip install youtube-transcript-api`." 
+ ) + + if self.proxy_url: + youtube_proxies = { + "http": self.proxy_url, + "https": self.proxy_url, + } + # Don't log complete URL because it might contain secrets + log.debug(f"Using proxy URL: {self.proxy_url[:14]}...") + else: + youtube_proxies = None + + try: + transcript_list = YouTubeTranscriptApi.list_transcripts( + self.video_id, proxies=youtube_proxies + ) + except Exception as e: + log.exception("Loading YouTube transcript failed") + return [] + + try: + transcript = transcript_list.find_transcript(self.language) + except NoTranscriptFound: + transcript = transcript_list.find_transcript(["en"]) + + transcript_pieces: List[Dict[str, Any]] = transcript.fetch() + + transcript = " ".join( + map( + lambda transcript_piece: transcript_piece["text"].strip(" "), + transcript_pieces, + ) + ) + return [Document(page_content=transcript, metadata=self._metadata)] diff --git a/backend/open_webui/apps/retrieval/models/colbert.py b/backend/open_webui/retrieval/models/colbert.py similarity index 100% rename from backend/open_webui/apps/retrieval/models/colbert.py rename to backend/open_webui/retrieval/models/colbert.py diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/retrieval/utils.py similarity index 76% rename from backend/open_webui/apps/retrieval/utils.py rename to backend/open_webui/retrieval/utils.py index 153bd804f..9444ade95 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -3,6 +3,7 @@ import os import uuid from typing import Optional, Union +import asyncio import requests from huggingface_hub import snapshot_download @@ -10,17 +11,10 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document - -from open_webui.apps.ollama.main import ( - GenerateEmbedForm, - generate_ollama_batch_embeddings, -) -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message from open_webui.env import SRC_LOG_LEVELS -from open_webui.config import DEFAULT_RAG_TEMPLATE - log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -76,7 +70,7 @@ def query_doc( limit=k, ) - log.info(f"query_doc:result {result}") + log.info(f"query_doc:result {result.ids} {result.metadatas}") return result except Exception as e: print(e) @@ -127,7 +121,10 @@ def query_doc_with_hybrid_search( "metadatas": [[d.metadata for d in result]], } - log.info(f"query_doc_with_hybrid_search:result {result}") + log.info( + "query_doc_with_hybrid_search:result " + + f'{result["metadatas"]} {result["distances"]}' + ) return result except Exception as e: raise e @@ -178,35 +175,34 @@ def merge_and_sort_query_results( def query_collection( collection_names: list[str], - query: str, + queries: list[str], embedding_function, k: int, ) -> dict: - results = [] - query_embedding = embedding_function(query) - - for collection_name in collection_names: - if collection_name: - try: - result = query_doc( - collection_name=collection_name, - k=k, - query_embedding=query_embedding, - ) - if result is not None: - results.append(result.model_dump()) - except Exception as e: - log.exception(f"Error when querying the collection: {e}") - else: - pass + for query in queries: + query_embedding = embedding_function(query) + for collection_name in collection_names: + if collection_name: + try: + result = 
query_doc( + collection_name=collection_name, + k=k, + query_embedding=query_embedding, + ) + if result is not None: + results.append(result.model_dump()) + except Exception as e: + log.exception(f"Error when querying the collection: {e}") + else: + pass return merge_and_sort_query_results(results, k=k) def query_collection_with_hybrid_search( collection_names: list[str], - query: str, + queries: list[str], embedding_function, k: int, reranking_function, @@ -216,15 +212,16 @@ def query_collection_with_hybrid_search( error = False for collection_name in collection_names: try: - result = query_doc_with_hybrid_search( - collection_name=collection_name, - query=query, - embedding_function=embedding_function, - k=k, - reranking_function=reranking_function, - r=r, - ) - results.append(result) + for query in queries: + result = query_doc_with_hybrid_search( + collection_name=collection_name, + query=query, + embedding_function=embedding_function, + k=k, + reranking_function=reranking_function, + r=r, + ) + results.append(result) except Exception as e: log.exception( "Error when querying the collection with " f"hybrid_search: {e}" @@ -239,50 +236,12 @@ def query_collection_with_hybrid_search( return merge_and_sort_query_results(results, k=k, reverse=True) -def rag_template(template: str, context: str, query: str): - if template == "": - template = DEFAULT_RAG_TEMPLATE - - if "[context]" not in template and "{{CONTEXT}}" not in template: - log.debug( - "WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder." - ) - - if "" in context and "" in context: - log.debug( - "WARNING: Potential prompt injection attack: the RAG " - "context contains '' and ''. This might be " - "nothing, or the user might be trying to hack something." 
- ) - - query_placeholders = [] - if "[query]" in context: - query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}" - template = template.replace("[query]", query_placeholder) - query_placeholders.append(query_placeholder) - - if "{{QUERY}}" in context: - query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}" - template = template.replace("{{QUERY}}", query_placeholder) - query_placeholders.append(query_placeholder) - - template = template.replace("[context]", context) - template = template.replace("{{CONTEXT}}", context) - template = template.replace("[query]", query) - template = template.replace("{{QUERY}}", query) - - for query_placeholder in query_placeholders: - template = template.replace(query_placeholder, query) - - return template - - def get_embedding_function( embedding_engine, embedding_model, embedding_function, - openai_key, - openai_url, + url, + key, embedding_batch_size, ): if embedding_engine == "": @@ -292,8 +251,8 @@ def get_embedding_function( engine=embedding_engine, model=embedding_model, text=query, - key=openai_key if embedding_engine == "openai" else "", - url=openai_url if embedding_engine == "openai" else "", + url=url, + key=key, ) def generate_multiple(query, func): @@ -308,17 +267,16 @@ def get_embedding_function( return lambda query: generate_multiple(query, func) -def get_rag_context( +def get_sources_from_files( files, - messages, + queries, embedding_function, k, reranking_function, r, hybrid_search, ): - log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}") - query = get_last_user_message(messages) + log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}") extracted_collections = [] relevant_contexts = [] @@ -360,7 +318,7 @@ def get_rag_context( try: context = query_collection_with_hybrid_search( collection_names=collection_names, - query=query, + queries=queries, embedding_function=embedding_function, k=k, reranking_function=reranking_function, @@ -375,7 +333,7 @@ def get_rag_context( if (not hybrid_search) or (context is None): context = query_collection( collection_names=collection_names, - query=query, + queries=queries, embedding_function=embedding_function, k=k, ) @@ -389,43 +347,24 @@ def get_rag_context( del file["data"] relevant_contexts.append({**context, "file": file}) - contexts = [] - citations = [] + sources = [] for context in relevant_contexts: try: if "documents" in context: - file_names = list( - set( - [ - metadata["name"] - for metadata in context["metadatas"][0] - if metadata is not None and "name" in metadata - ] - ) - ) - contexts.append( - ((", ".join(file_names) + ":\n\n") if file_names else "") - + "\n\n".join( - [text for text in context["documents"][0] if text is not None] - ) - ) - if "metadatas" in context: - citation = { + source = { "source": context["file"], "document": context["documents"][0], "metadata": context["metadatas"][0], } if "distances" in context and context["distances"]: - citation["distances"] = context["distances"][0] - citations.append(citation) + source["distances"] = context["distances"][0] + + sources.append(source) except Exception as e: log.exception(e) - print("contexts", contexts) - print("citations", citations) - - return contexts, citations + return sources def get_model_path(model: str, update_model: bool = False): @@ -467,7 +406,7 @@ def get_model_path(model: str, update_model: bool = False): def generate_openai_batch_embeddings( - model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1" + model: str, texts: 
list[str], url: str = "https://api.openai.com/v1", key: str = "" ) -> Optional[list[list[float]]]: try: r = requests.post( @@ -489,29 +428,49 @@ def generate_openai_batch_embeddings( return None +def generate_ollama_batch_embeddings( + model: str, texts: list[str], url: str, key: str = "" +) -> Optional[list[list[float]]]: + try: + r = requests.post( + f"{url}/api/embed", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {key}", + }, + json={"input": texts, "model": model}, + ) + r.raise_for_status() + data = r.json() + + if "embeddings" in data: + return data["embeddings"] + else: + raise "Something went wrong :/" + except Exception as e: + print(e) + return None + + def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs): + url = kwargs.get("url", "") + key = kwargs.get("key", "") + if engine == "ollama": if isinstance(text, list): embeddings = generate_ollama_batch_embeddings( - GenerateEmbedForm(**{"model": model, "input": text}) + **{"model": model, "texts": text, "url": url, "key": key} ) else: embeddings = generate_ollama_batch_embeddings( - GenerateEmbedForm(**{"model": model, "input": [text]}) + **{"model": model, "texts": [text], "url": url, "key": key} ) - return ( - embeddings["embeddings"][0] - if isinstance(text, str) - else embeddings["embeddings"] - ) + return embeddings[0] if isinstance(text, str) else embeddings elif engine == "openai": - key = kwargs.get("key", "") - url = kwargs.get("url", "https://api.openai.com/v1") - if isinstance(text, list): - embeddings = generate_openai_batch_embeddings(model, text, key, url) + embeddings = generate_openai_batch_embeddings(model, text, url, key) else: - embeddings = generate_openai_batch_embeddings(model, [text], key, url) + embeddings = generate_openai_batch_embeddings(model, [text], url, key) return embeddings[0] if isinstance(text, str) else embeddings diff --git a/backend/open_webui/retrieval/vector/connector.py b/backend/open_webui/retrieval/vector/connector.py new file mode 100644 index 000000000..bf97bc7b1 --- /dev/null +++ b/backend/open_webui/retrieval/vector/connector.py @@ -0,0 +1,22 @@ +from open_webui.config import VECTOR_DB + +if VECTOR_DB == "milvus": + from open_webui.retrieval.vector.dbs.milvus import MilvusClient + + VECTOR_DB_CLIENT = MilvusClient() +elif VECTOR_DB == "qdrant": + from open_webui.retrieval.vector.dbs.qdrant import QdrantClient + + VECTOR_DB_CLIENT = QdrantClient() +elif VECTOR_DB == "opensearch": + from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient + + VECTOR_DB_CLIENT = OpenSearchClient() +elif VECTOR_DB == "pgvector": + from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient + + VECTOR_DB_CLIENT = PgvectorClient() +else: + from open_webui.retrieval.vector.dbs.chroma import ChromaClient + + VECTOR_DB_CLIENT = ChromaClient() diff --git a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py similarity index 89% rename from backend/open_webui/apps/retrieval/vector/dbs/chroma.py rename to backend/open_webui/retrieval/vector/dbs/chroma.py index 7782671a2..00d73a889 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches from typing import Optional -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from 
open_webui.config import ( CHROMA_DATA_PATH, CHROMA_HTTP_HOST, @@ -13,11 +13,24 @@ from open_webui.config import ( CHROMA_HTTP_SSL, CHROMA_TENANT, CHROMA_DATABASE, + CHROMA_CLIENT_AUTH_PROVIDER, + CHROMA_CLIENT_AUTH_CREDENTIALS, ) class ChromaClient: def __init__(self): + settings_dict = { + "allow_reset": True, + "anonymized_telemetry": False, + } + if CHROMA_CLIENT_AUTH_PROVIDER is not None: + settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER + if CHROMA_CLIENT_AUTH_CREDENTIALS is not None: + settings_dict["chroma_client_auth_credentials"] = ( + CHROMA_CLIENT_AUTH_CREDENTIALS + ) + if CHROMA_HTTP_HOST != "": self.client = chromadb.HttpClient( host=CHROMA_HTTP_HOST, @@ -26,12 +39,12 @@ class ChromaClient: ssl=CHROMA_HTTP_SSL, tenant=CHROMA_TENANT, database=CHROMA_DATABASE, - settings=Settings(allow_reset=True, anonymized_telemetry=False), + settings=Settings(**settings_dict), ) else: self.client = chromadb.PersistentClient( path=CHROMA_DATA_PATH, - settings=Settings(allow_reset=True, anonymized_telemetry=False), + settings=Settings(**settings_dict), tenant=CHROMA_TENANT, database=CHROMA_DATABASE, ) diff --git a/backend/open_webui/apps/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py similarity index 99% rename from backend/open_webui/apps/retrieval/vector/dbs/milvus.py rename to backend/open_webui/retrieval/vector/dbs/milvus.py index 5351f860e..31d890664 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -4,7 +4,7 @@ import json from typing import Optional -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( MILVUS_URI, ) diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py new file mode 100644 index 000000000..b3d8b5eb8 --- /dev/null +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -0,0 +1,178 @@ +from opensearchpy import OpenSearch +from typing import Optional + +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.config import ( + OPENSEARCH_URI, + OPENSEARCH_SSL, + OPENSEARCH_CERT_VERIFY, + OPENSEARCH_USERNAME, + OPENSEARCH_PASSWORD, +) + + +class OpenSearchClient: + def __init__(self): + self.index_prefix = "open_webui" + self.client = OpenSearch( + hosts=[OPENSEARCH_URI], + use_ssl=OPENSEARCH_SSL, + verify_certs=OPENSEARCH_CERT_VERIFY, + http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD), + ) + + def _result_to_get_result(self, result) -> GetResult: + ids = [] + documents = [] + metadatas = [] + + for hit in result["hits"]["hits"]: + ids.append(hit["_id"]) + documents.append(hit["_source"].get("text")) + metadatas.append(hit["_source"].get("metadata")) + + return GetResult(ids=ids, documents=documents, metadatas=metadatas) + + def _result_to_search_result(self, result) -> SearchResult: + ids = [] + distances = [] + documents = [] + metadatas = [] + + for hit in result["hits"]["hits"]: + ids.append(hit["_id"]) + distances.append(hit["_score"]) + documents.append(hit["_source"].get("text")) + metadatas.append(hit["_source"].get("metadata")) + + return SearchResult( + ids=ids, distances=distances, documents=documents, metadatas=metadatas + ) + + def _create_index(self, index_name: str, dimension: int): + body = { + "mappings": { + "properties": { + "id": {"type": "keyword"}, + 
"vector": { + "type": "dense_vector", + "dims": dimension, # Adjust based on your vector dimensions + "index": true, + "similarity": "faiss", + "method": { + "name": "hnsw", + "space_type": "ip", # Use inner product to approximate cosine similarity + "engine": "faiss", + "ef_construction": 128, + "m": 16, + }, + }, + "text": {"type": "text"}, + "metadata": {"type": "object"}, + } + } + } + self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body) + + def _create_batches(self, items: list[VectorItem], batch_size=100): + for i in range(0, len(items), batch_size): + yield items[i : i + batch_size] + + def has_collection(self, index_name: str) -> bool: + # has_collection here means has index. + # We are simply adapting to the norms of the other DBs. + return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}") + + def delete_colleciton(self, index_name: str): + # delete_collection here means delete index. + # We are simply adapting to the norms of the other DBs. + self.client.indices.delete(index=f"{self.index_prefix}_{index_name}") + + def search( + self, index_name: str, vectors: list[list[float]], limit: int + ) -> Optional[SearchResult]: + query = { + "size": limit, + "_source": ["text", "metadata"], + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": "cosineSimilarity(params.vector, 'vector') + 1.0", + "params": { + "vector": vectors[0] + }, # Assuming single query vector + }, + } + }, + } + + result = self.client.search( + index=f"{self.index_prefix}_{index_name}", body=query + ) + + return self._result_to_search_result(result) + + def get_or_create_index(self, index_name: str, dimension: int): + if not self.has_index(index_name): + self._create_index(index_name, dimension) + + def get(self, index_name: str) -> Optional[GetResult]: + query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]} + + result = self.client.search( + index=f"{self.index_prefix}_{index_name}", body=query + ) + return self._result_to_get_result(result) + + def insert(self, index_name: str, items: list[VectorItem]): + if not self.has_index(index_name): + self._create_index(index_name, dimension=len(items[0]["vector"])) + + for batch in self._create_batches(items): + actions = [ + { + "index": { + "_id": item["id"], + "_source": { + "vector": item["vector"], + "text": item["text"], + "metadata": item["metadata"], + }, + } + } + for item in batch + ] + self.client.bulk(actions) + + def upsert(self, index_name: str, items: list[VectorItem]): + if not self.has_index(index_name): + self._create_index(index_name, dimension=len(items[0]["vector"])) + + for batch in self._create_batches(items): + actions = [ + { + "index": { + "_id": item["id"], + "_source": { + "vector": item["vector"], + "text": item["text"], + "metadata": item["metadata"], + }, + } + } + for item in batch + ] + self.client.bulk(actions) + + def delete(self, index_name: str, ids: list[str]): + actions = [ + {"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}} + for id in ids + ] + self.client.bulk(body=actions) + + def reset(self): + indices = self.client.indices.get(index=f"{self.index_prefix}_*") + for index in indices: + self.client.indices.delete(index=index) diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py new file mode 100644 index 000000000..cb8c545e9 --- /dev/null +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -0,0 +1,354 @@ +from typing import Optional, List, 
Dict, Any +from sqlalchemy import ( + cast, + column, + create_engine, + Column, + Integer, + select, + text, + Text, + values, +) +from sqlalchemy.sql import true +from sqlalchemy.pool import NullPool + +from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker +from sqlalchemy.dialects.postgresql import JSONB, array +from pgvector.sqlalchemy import Vector +from sqlalchemy.ext.mutable import MutableDict + +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.config import PGVECTOR_DB_URL + +VECTOR_LENGTH = 1536 +Base = declarative_base() + + +class DocumentChunk(Base): + __tablename__ = "document_chunk" + + id = Column(Text, primary_key=True) + vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True) + collection_name = Column(Text, nullable=False) + text = Column(Text, nullable=True) + vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True) + + +class PgvectorClient: + def __init__(self) -> None: + + # if no pgvector uri, use the existing database connection + if not PGVECTOR_DB_URL: + from open_webui.internal.db import Session + + self.session = Session + else: + engine = create_engine( + PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool + ) + SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=engine, expire_on_commit=False + ) + self.session = scoped_session(SessionLocal) + + try: + # Ensure the pgvector extension is available + self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;")) + + # Create the tables if they do not exist + # Base.metadata.create_all requires a bind (engine or connection) + # Get the connection from the session + connection = self.session.connection() + Base.metadata.create_all(bind=connection) + + # Create an index on the vector column if it doesn't exist + self.session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector " + "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);" + ) + ) + self.session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name " + "ON document_chunk (collection_name);" + ) + ) + self.session.commit() + print("Initialization complete.") + except Exception as e: + self.session.rollback() + print(f"Error during initialization: {e}") + raise + + def adjust_vector_length(self, vector: List[float]) -> List[float]: + # Adjust vector to have length VECTOR_LENGTH + current_length = len(vector) + if current_length < VECTOR_LENGTH: + # Pad the vector with zeros + vector += [0.0] * (VECTOR_LENGTH - current_length) + elif current_length > VECTOR_LENGTH: + raise Exception( + f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}" + ) + return vector + + def insert(self, collection_name: str, items: List[VectorItem]) -> None: + try: + new_items = [] + for item in items: + vector = self.adjust_vector_length(item["vector"]) + new_chunk = DocumentChunk( + id=item["id"], + vector=vector, + collection_name=collection_name, + text=item["text"], + vmetadata=item["metadata"], + ) + new_items.append(new_chunk) + self.session.bulk_save_objects(new_items) + self.session.commit() + print( + f"Inserted {len(new_items)} items into collection '{collection_name}'." 
+ ) + except Exception as e: + self.session.rollback() + print(f"Error during insert: {e}") + raise + + def upsert(self, collection_name: str, items: List[VectorItem]) -> None: + try: + for item in items: + vector = self.adjust_vector_length(item["vector"]) + existing = ( + self.session.query(DocumentChunk) + .filter(DocumentChunk.id == item["id"]) + .first() + ) + if existing: + existing.vector = vector + existing.text = item["text"] + existing.vmetadata = item["metadata"] + existing.collection_name = ( + collection_name # Update collection_name if necessary + ) + else: + new_chunk = DocumentChunk( + id=item["id"], + vector=vector, + collection_name=collection_name, + text=item["text"], + vmetadata=item["metadata"], + ) + self.session.add(new_chunk) + self.session.commit() + print(f"Upserted {len(items)} items into collection '{collection_name}'.") + except Exception as e: + self.session.rollback() + print(f"Error during upsert: {e}") + raise + + def search( + self, + collection_name: str, + vectors: List[List[float]], + limit: Optional[int] = None, + ) -> Optional[SearchResult]: + try: + if not vectors: + return None + + # Adjust query vectors to VECTOR_LENGTH + vectors = [self.adjust_vector_length(vector) for vector in vectors] + num_queries = len(vectors) + + def vector_expr(vector): + return cast(array(vector), Vector(VECTOR_LENGTH)) + + # Create the values for query vectors + qid_col = column("qid", Integer) + q_vector_col = column("q_vector", Vector(VECTOR_LENGTH)) + query_vectors = ( + values(qid_col, q_vector_col) + .data( + [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)] + ) + .alias("query_vectors") + ) + + # Build the lateral subquery for each query vector + subq = ( + select( + DocumentChunk.id, + DocumentChunk.text, + DocumentChunk.vmetadata, + ( + DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector) + ).label("distance"), + ) + .where(DocumentChunk.collection_name == collection_name) + .order_by( + (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)) + ) + ) + if limit is not None: + subq = subq.limit(limit) + subq = subq.lateral("result") + + # Build the main query by joining query_vectors and the lateral subquery + stmt = ( + select( + query_vectors.c.qid, + subq.c.id, + subq.c.text, + subq.c.vmetadata, + subq.c.distance, + ) + .select_from(query_vectors) + .join(subq, true()) + .order_by(query_vectors.c.qid, subq.c.distance) + ) + + result_proxy = self.session.execute(stmt) + results = result_proxy.all() + + ids = [[] for _ in range(num_queries)] + distances = [[] for _ in range(num_queries)] + documents = [[] for _ in range(num_queries)] + metadatas = [[] for _ in range(num_queries)] + + if not results: + return SearchResult( + ids=ids, + distances=distances, + documents=documents, + metadatas=metadatas, + ) + + for row in results: + qid = int(row.qid) + ids[qid].append(row.id) + distances[qid].append(row.distance) + documents[qid].append(row.text) + metadatas[qid].append(row.vmetadata) + + return SearchResult( + ids=ids, distances=distances, documents=documents, metadatas=metadatas + ) + except Exception as e: + print(f"Error during search: {e}") + return None + + def query( + self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None + ) -> Optional[GetResult]: + try: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + + for key, value in filter.items(): + query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) + + if limit is not None: 
+ query = query.limit(limit) + + results = query.all() + + if not results: + return None + + ids = [[result.id for result in results]] + documents = [[result.text for result in results]] + metadatas = [[result.vmetadata for result in results]] + + return GetResult( + ids=ids, + documents=documents, + metadatas=metadatas, + ) + except Exception as e: + print(f"Error during query: {e}") + return None + + def get( + self, collection_name: str, limit: Optional[int] = None + ) -> Optional[GetResult]: + try: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + if limit is not None: + query = query.limit(limit) + + results = query.all() + + if not results: + return None + + ids = [[result.id for result in results]] + documents = [[result.text for result in results]] + metadatas = [[result.vmetadata for result in results]] + + return GetResult(ids=ids, documents=documents, metadatas=metadatas) + except Exception as e: + print(f"Error during get: {e}") + return None + + def delete( + self, + collection_name: str, + ids: Optional[List[str]] = None, + filter: Optional[Dict[str, Any]] = None, + ) -> None: + try: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + if ids: + query = query.filter(DocumentChunk.id.in_(ids)) + if filter: + for key, value in filter.items(): + query = query.filter( + DocumentChunk.vmetadata[key].astext == str(value) + ) + deleted = query.delete(synchronize_session=False) + self.session.commit() + print(f"Deleted {deleted} items from collection '{collection_name}'.") + except Exception as e: + self.session.rollback() + print(f"Error during delete: {e}") + raise + + def reset(self) -> None: + try: + deleted = self.session.query(DocumentChunk).delete() + self.session.commit() + print( + f"Reset complete. Deleted {deleted} items from 'document_chunk' table." 
+ ) + except Exception as e: + self.session.rollback() + print(f"Error during reset: {e}") + raise + + def close(self) -> None: + pass + + def has_collection(self, collection_name: str) -> bool: + try: + exists = ( + self.session.query(DocumentChunk) + .filter(DocumentChunk.collection_name == collection_name) + .first() + is not None + ) + return exists + except Exception as e: + print(f"Error checking collection existence: {e}") + return False + + def delete_collection(self, collection_name: str) -> None: + self.delete(collection_name) + print(f"Collection '{collection_name}' deleted.") diff --git a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py similarity index 95% rename from backend/open_webui/apps/retrieval/vector/dbs/qdrant.py rename to backend/open_webui/retrieval/vector/dbs/qdrant.py index c1e06872f..f077ae45a 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -4,8 +4,8 @@ from qdrant_client import QdrantClient as Qclient from qdrant_client.http.models import PointStruct from qdrant_client.models import models -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult -from open_webui.config import QDRANT_URI +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.config import QDRANT_URI, QDRANT_API_KEY NO_LIMIT = 999999999 @@ -14,7 +14,12 @@ class QdrantClient: def __init__(self): self.collection_prefix = "open-webui" self.QDRANT_URI = QDRANT_URI - self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None + self.QDRANT_API_KEY = QDRANT_API_KEY + self.client = ( + Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY) + if self.QDRANT_URI + else None + ) def _result_to_get_result(self, points) -> GetResult: ids = [] diff --git a/backend/open_webui/apps/retrieval/vector/main.py b/backend/open_webui/retrieval/vector/main.py similarity index 100% rename from backend/open_webui/apps/retrieval/vector/main.py rename to backend/open_webui/retrieval/vector/main.py diff --git a/backend/open_webui/retrieval/web/bing.py b/backend/open_webui/retrieval/web/bing.py new file mode 100644 index 000000000..09beb3460 --- /dev/null +++ b/backend/open_webui/retrieval/web/bing.py @@ -0,0 +1,73 @@ +import logging +import os +from pprint import pprint +from typing import Optional +import requests +from open_webui.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.env import SRC_LOG_LEVELS +import argparse + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) +""" +Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview +""" + + +def search_bing( + subscription_key: str, + endpoint: str, + locale: str, + query: str, + count: int, + filter_list: Optional[list[str]] = None, +) -> list[SearchResult]: + mkt = locale + params = {"q": query, "mkt": mkt, "answerCount": count} + headers = {"Ocp-Apim-Subscription-Key": subscription_key} + + try: + response = requests.get(endpoint, headers=headers, params=params) + response.raise_for_status() + json_response = response.json() + results = json_response.get("webPages", {}).get("value", []) + if filter_list: + results = get_filtered_results(results, filter_list) + return [ + SearchResult( + link=result["url"], + title=result.get("name"), + snippet=result.get("snippet"), + ) + for result in results + ] + except Exception as ex: + log.error(f"Error: {ex}") + raise ex + + 
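Not part of the patch: a minimal usage sketch for `search_bing`, assuming a placeholder subscription key and an assumed Bing Web Search endpoint URL; keyword arguments are used here because the `main()` helper below passes its values positionally in a different order than this signature expects.

# Illustrative only; the key and endpoint are placeholders.
results = search_bing(
    subscription_key="YOUR_BING_SUBSCRIPTION_KEY",
    endpoint="https://api.bing.microsoft.com/v7.0/search",  # assumed endpoint
    locale="en-US",
    query="open webui retrieval augmented generation",
    count=5,
    filter_list=["github.com"],  # optional filter list passed to get_filtered_results
)
for result in results:
    print(result.link, result.title)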
+def main(): + parser = argparse.ArgumentParser(description="Search Bing from the command line.") + parser.add_argument( + "query", + type=str, + default="Top 10 international news today", + help="The search query.", + ) + parser.add_argument( + "--count", type=int, default=10, help="Number of search results to return." + ) + parser.add_argument( + "--filter", nargs="*", help="List of filters to apply to the search results." + ) + parser.add_argument( + "--locale", + type=str, + default="en-US", + help="The locale to use for the search, maps to market in api", + ) + + args = parser.parse_args() + + results = search_bing(args.locale, args.query, args.count, args.filter) + pprint(results) diff --git a/backend/open_webui/apps/retrieval/web/brave.py b/backend/open_webui/retrieval/web/brave.py similarity index 93% rename from backend/open_webui/apps/retrieval/web/brave.py rename to backend/open_webui/retrieval/web/brave.py index f988b3b08..3075db990 100644 --- a/backend/open_webui/apps/retrieval/web/brave.py +++ b/backend/open_webui/retrieval/web/brave.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/duckduckgo.py b/backend/open_webui/retrieval/web/duckduckgo.py similarity index 95% rename from backend/open_webui/apps/retrieval/web/duckduckgo.py rename to backend/open_webui/retrieval/web/duckduckgo.py index 11e512296..7c0c3f1c2 100644 --- a/backend/open_webui/apps/retrieval/web/duckduckgo.py +++ b/backend/open_webui/retrieval/web/duckduckgo.py @@ -1,7 +1,7 @@ import logging from typing import Optional -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from duckduckgo_search import DDGS from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/apps/retrieval/web/google_pse.py b/backend/open_webui/retrieval/web/google_pse.py similarity index 94% rename from backend/open_webui/apps/retrieval/web/google_pse.py rename to backend/open_webui/retrieval/web/google_pse.py index 61b919583..2c51dd3c9 100644 --- a/backend/open_webui/apps/retrieval/web/google_pse.py +++ b/backend/open_webui/retrieval/web/google_pse.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/jina_search.py b/backend/open_webui/retrieval/web/jina_search.py similarity index 81% rename from backend/open_webui/apps/retrieval/web/jina_search.py rename to backend/open_webui/retrieval/web/jina_search.py index 487bbc948..3de6c1807 100644 --- a/backend/open_webui/apps/retrieval/web/jina_search.py +++ b/backend/open_webui/retrieval/web/jina_search.py @@ -1,7 +1,7 @@ import logging import requests -from open_webui.apps.retrieval.web.main import SearchResult +from open_webui.retrieval.web.main import SearchResult from open_webui.env import SRC_LOG_LEVELS from yarl import URL @@ -9,7 +9,7 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -def search_jina(query: str, count: int) -> 
list[SearchResult]: +def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]: """ Search using Jina's Search API and return the results as a list of SearchResult objects. Args: @@ -20,9 +20,7 @@ def search_jina(query: str, count: int) -> list[SearchResult]: list[SearchResult]: A list of search results """ jina_search_endpoint = "https://s.jina.ai/" - headers = { - "Accept": "application/json", - } + headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"} url = str(URL(jina_search_endpoint + query)) response = requests.get(url, headers=headers) response.raise_for_status() diff --git a/backend/open_webui/retrieval/web/kagi.py b/backend/open_webui/retrieval/web/kagi.py new file mode 100644 index 000000000..0b69da8bc --- /dev/null +++ b/backend/open_webui/retrieval/web/kagi.py @@ -0,0 +1,48 @@ +import logging +from typing import Optional + +import requests +from open_webui.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_kagi( + api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None +) -> list[SearchResult]: + """Search using Kagi's Search API and return the results as a list of SearchResult objects. + + The Search API will inherit the settings in your account, including results personalization and snippet length. + + Args: + api_key (str): A Kagi Search API key + query (str): The query to search for + count (int): The number of results to return + """ + url = "https://kagi.com/api/v0/search" + headers = { + "Authorization": f"Bot {api_key}", + } + params = {"q": query, "limit": count} + + response = requests.get(url, headers=headers, params=params) + response.raise_for_status() + json_response = response.json() + search_results = json_response.get("data", []) + + results = [ + SearchResult( + link=result["url"], title=result["title"], snippet=result.get("snippet") + ) + for result in search_results + if result["t"] == 0 + ] + + print(results) + + if filter_list: + results = get_filtered_results(results, filter_list) + + return results diff --git a/backend/open_webui/apps/retrieval/web/main.py b/backend/open_webui/retrieval/web/main.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/main.py rename to backend/open_webui/retrieval/web/main.py diff --git a/backend/open_webui/retrieval/web/mojeek.py b/backend/open_webui/retrieval/web/mojeek.py new file mode 100644 index 000000000..d298b0ee5 --- /dev/null +++ b/backend/open_webui/retrieval/web/mojeek.py @@ -0,0 +1,40 @@ +import logging +from typing import Optional + +import requests +from open_webui.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_mojeek( + api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None +) -> list[SearchResult]: + """Search using Mojeek's Search API and return the results as a list of SearchResult objects. 
+ + Args: + api_key (str): A Mojeek Search API key + query (str): The query to search for + """ + url = "https://api.mojeek.com/search" + headers = { + "Accept": "application/json", + } + params = {"q": query, "api_key": api_key, "fmt": "json", "t": count} + + response = requests.get(url, headers=headers, params=params) + response.raise_for_status() + json_response = response.json() + results = json_response.get("response", {}).get("results", []) + print(results) + if filter_list: + results = get_filtered_results(results, filter_list) + + return [ + SearchResult( + link=result["url"], title=result.get("title"), snippet=result.get("desc") + ) + for result in results + ] diff --git a/backend/open_webui/apps/retrieval/web/searchapi.py b/backend/open_webui/retrieval/web/searchapi.py similarity index 93% rename from backend/open_webui/apps/retrieval/web/searchapi.py rename to backend/open_webui/retrieval/web/searchapi.py index 412dc6b69..38bc0b574 100644 --- a/backend/open_webui/apps/retrieval/web/searchapi.py +++ b/backend/open_webui/retrieval/web/searchapi.py @@ -3,7 +3,7 @@ from typing import Optional from urllib.parse import urlencode import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/searxng.py b/backend/open_webui/retrieval/web/searxng.py similarity index 97% rename from backend/open_webui/apps/retrieval/web/searxng.py rename to backend/open_webui/retrieval/web/searxng.py index cb1eaf91d..15e3c098a 100644 --- a/backend/open_webui/apps/retrieval/web/searxng.py +++ b/backend/open_webui/retrieval/web/searxng.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/serper.py b/backend/open_webui/retrieval/web/serper.py similarity index 93% rename from backend/open_webui/apps/retrieval/web/serper.py rename to backend/open_webui/retrieval/web/serper.py index 436fa167e..685e34375 100644 --- a/backend/open_webui/apps/retrieval/web/serper.py +++ b/backend/open_webui/retrieval/web/serper.py @@ -3,7 +3,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/serply.py b/backend/open_webui/retrieval/web/serply.py similarity index 95% rename from backend/open_webui/apps/retrieval/web/serply.py rename to backend/open_webui/retrieval/web/serply.py index 1c2521c47..a9b473eb0 100644 --- a/backend/open_webui/apps/retrieval/web/serply.py +++ b/backend/open_webui/retrieval/web/serply.py @@ -3,7 +3,7 @@ from typing import Optional from urllib.parse import urlencode import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git 
a/backend/open_webui/apps/retrieval/web/serpstack.py b/backend/open_webui/retrieval/web/serpstack.py similarity index 94% rename from backend/open_webui/apps/retrieval/web/serpstack.py rename to backend/open_webui/retrieval/web/serpstack.py index b655934de..d4dbda57c 100644 --- a/backend/open_webui/apps/retrieval/web/serpstack.py +++ b/backend/open_webui/retrieval/web/serpstack.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/web/tavily.py b/backend/open_webui/retrieval/web/tavily.py similarity index 94% rename from backend/open_webui/apps/retrieval/web/tavily.py rename to backend/open_webui/retrieval/web/tavily.py index 03b0be75a..cc468725d 100644 --- a/backend/open_webui/apps/retrieval/web/tavily.py +++ b/backend/open_webui/retrieval/web/tavily.py @@ -1,7 +1,7 @@ import logging import requests -from open_webui.apps.retrieval.web.main import SearchResult +from open_webui.retrieval.web.main import SearchResult from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/testdata/bing.json b/backend/open_webui/retrieval/web/testdata/bing.json new file mode 100644 index 000000000..80324f3b4 --- /dev/null +++ b/backend/open_webui/retrieval/web/testdata/bing.json @@ -0,0 +1,58 @@ +{ + "_type": "SearchResponse", + "queryContext": { + "originalQuery": "Top 10 international results" + }, + "webPages": { + "webSearchUrl": "https://www.bing.com/search?q=Top+10+international+results", + "totalEstimatedMatches": 687, + "value": [ + { + "id": "https://api.bing.microsoft.com/api/v7/#WebPages.0", + "name": "2024 Mexican Grand Prix - F1 results and latest standings ... - PlanetF1", + "url": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings", + "datePublished": "2024-10-27T00:00:00.0000000", + "datePublishedFreshnessText": "1 day ago", + "isFamilyFriendly": true, + "displayUrl": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings", + "snippet": "Nico Hulkenberg and Pierre Gasly completed the top 10. A full report of the Mexican Grand Prix is available at the bottom of this article. F1 results – 2024 Mexican Grand Prix", + "dateLastCrawled": "2024-10-28T07:15:00.0000000Z", + "cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=916492551782&mkt=en-US&setlang=en-US&w=zBsfaAPyF2tUrHFHr_vFFdUm8sng4g34", + "language": "en", + "isNavigational": false, + "noCache": false + }, + { + "id": "https://api.bing.microsoft.com/api/v7/#WebPages.1", + "name": "F1 Results Today: HUGE Verstappen penalties cause major title change", + "url": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max-verstappen-penalties-cause-major-title-change/", + "datePublished": "2024-10-27T00:00:00.0000000", + "datePublishedFreshnessText": "1 day ago", + "isFamilyFriendly": true, + "displayUrl": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max...", + "snippet": "Elsewhere, Mercedes duo Lewis Hamilton and George Russell came home in P4 and P5 respectively. Meanwhile, the surprise package of the day were Haas, with both Kevin Magnussen and Nico Hulkenberg finishing inside the points.. 
READ MORE: RB star issues apology after red flag CRASH at Mexican GP Mexican Grand Prix 2024 results. 1. Carlos Sainz [Ferrari] 2. Lando Norris [McLaren] - +4.705", + "dateLastCrawled": "2024-10-28T06:06:00.0000000Z", + "cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=2840656522642&mkt=en-US&setlang=en-US&w=-Tbkwxnq52jZCvG7l3CtgcwT1vwAjIUD", + "language": "en", + "isNavigational": false, + "noCache": false + }, + { + "id": "https://api.bing.microsoft.com/api/v7/#WebPages.2", + "name": "International Power Rankings: England flying, Kangaroos cruising, Fiji rise", + "url": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying-kangaroos-cruising-fiji-rise", + "datePublished": "2024-10-28T00:00:00.0000000", + "datePublishedFreshnessText": "7 hours ago", + "isFamilyFriendly": true, + "displayUrl": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying...", + "snippet": "LRL RECOMMENDS: England player ratings from first Test against Samoa as omnificent George Williams scores perfect 10. 2. Australia (Men) – SAME. The Kangaroos remain 2nd in our Power Rankings after their 22-10 win against New Zealand in Christchurch on Sunday. As was the case in their win against Tonga last week, Mal Meninga’s side weren ...", + "dateLastCrawled": "2024-10-28T07:09:00.0000000Z", + "cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=1535008462672&mkt=en-US&setlang=en-US&w=82ujhH4Kp0iuhCS7wh1xLUFYUeetaVVm", + "language": "en", + "isNavigational": false, + "noCache": false + } + ], + "someResultsRemoved": true + } +} diff --git a/backend/open_webui/apps/retrieval/web/testdata/brave.json b/backend/open_webui/retrieval/web/testdata/brave.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/brave.json rename to backend/open_webui/retrieval/web/testdata/brave.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/google_pse.json b/backend/open_webui/retrieval/web/testdata/google_pse.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/google_pse.json rename to backend/open_webui/retrieval/web/testdata/google_pse.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/searchapi.json b/backend/open_webui/retrieval/web/testdata/searchapi.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/searchapi.json rename to backend/open_webui/retrieval/web/testdata/searchapi.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/searxng.json b/backend/open_webui/retrieval/web/testdata/searxng.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/searxng.json rename to backend/open_webui/retrieval/web/testdata/searxng.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/serper.json b/backend/open_webui/retrieval/web/testdata/serper.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/serper.json rename to backend/open_webui/retrieval/web/testdata/serper.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/serply.json b/backend/open_webui/retrieval/web/testdata/serply.json similarity index 100% rename from backend/open_webui/apps/retrieval/web/testdata/serply.json rename to backend/open_webui/retrieval/web/testdata/serply.json diff --git a/backend/open_webui/apps/retrieval/web/testdata/serpstack.json b/backend/open_webui/retrieval/web/testdata/serpstack.json similarity index 100% rename from 
backend/open_webui/apps/retrieval/web/testdata/serpstack.json rename to backend/open_webui/retrieval/web/testdata/serpstack.json diff --git a/backend/open_webui/apps/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py similarity index 100% rename from backend/open_webui/apps/retrieval/web/utils.py rename to backend/open_webui/retrieval/web/utils.py diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py new file mode 100644 index 000000000..a26355945 --- /dev/null +++ b/backend/open_webui/routers/audio.py @@ -0,0 +1,703 @@ +import hashlib +import json +import logging +import os +import uuid +from functools import lru_cache +from pathlib import Path +from pydub import AudioSegment +from pydub.silence import split_on_silence + +import aiohttp +import aiofiles +import requests + +from fastapi import ( + Depends, + FastAPI, + File, + HTTPException, + Request, + UploadFile, + status, + APIRouter, +) +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from pydantic import BaseModel + + +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.config import ( + WHISPER_MODEL_AUTO_UPDATE, + WHISPER_MODEL_DIR, + CACHE_DIR, +) + +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import ( + ENV, + SRC_LOG_LEVELS, + DEVICE_TYPE, + ENABLE_FORWARD_USER_INFO_HEADERS, +) + + +router = APIRouter() + +# Constants +MAX_FILE_SIZE_MB = 25 +MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["AUDIO"]) + +SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") +SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) + + +########################################## +# +# Utility functions +# +########################################## + +from pydub import AudioSegment +from pydub.utils import mediainfo + + +def is_mp4_audio(file_path): + """Check if the given file is an MP4 audio file.""" + if not os.path.isfile(file_path): + print(f"File not found: {file_path}") + return False + + info = mediainfo(file_path) + if ( + info.get("codec_name") == "aac" + and info.get("codec_type") == "audio" + and info.get("codec_tag_string") == "mp4a" + ): + return True + return False + + +def convert_mp4_to_wav(file_path, output_path): + """Convert MP4 audio file to WAV format.""" + audio = AudioSegment.from_file(file_path, format="mp4") + audio.export(output_path, format="wav") + print(f"Converted {file_path} to {output_path}") + + +def set_faster_whisper_model(model: str, auto_update: bool = False): + whisper_model = None + if model: + from faster_whisper import WhisperModel + + faster_whisper_kwargs = { + "model_size_or_path": model, + "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu", + "compute_type": "int8", + "download_root": WHISPER_MODEL_DIR, + "local_files_only": not auto_update, + } + + try: + whisper_model = WhisperModel(**faster_whisper_kwargs) + except Exception: + log.warning( + "WhisperModel initialization failed, attempting download with local_files_only=False" + ) + faster_whisper_kwargs["local_files_only"] = False + whisper_model = WhisperModel(**faster_whisper_kwargs) + return whisper_model + + +########################################## +# +# Audio API +# +########################################## + + +class TTSConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + API_KEY: str + ENGINE: str + MODEL: str + VOICE: str + SPLIT_ON: str + 
AZURE_SPEECH_REGION: str + AZURE_SPEECH_OUTPUT_FORMAT: str + + +class STTConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + ENGINE: str + MODEL: str + WHISPER_MODEL: str + + +class AudioConfigUpdateForm(BaseModel): + tts: TTSConfigForm + stt: STTConfigForm + + +@router.get("/config") +async def get_audio_config(request: Request, user=Depends(get_admin_user)): + return { + "tts": { + "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY, + "API_KEY": request.app.state.config.TTS_API_KEY, + "ENGINE": request.app.state.config.TTS_ENGINE, + "MODEL": request.app.state.config.TTS_MODEL, + "VOICE": request.app.state.config.TTS_VOICE, + "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON, + "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION, + "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, + }, + "stt": { + "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY, + "ENGINE": request.app.state.config.STT_ENGINE, + "MODEL": request.app.state.config.STT_MODEL, + "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, + }, + } + + +@router.post("/config/update") +async def update_audio_config( + request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user) +): + request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL + request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY + request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY + request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE + request.app.state.config.TTS_MODEL = form_data.tts.MODEL + request.app.state.config.TTS_VOICE = form_data.tts.VOICE + request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON + request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION + request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = ( + form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT + ) + + request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL + request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY + request.app.state.config.STT_ENGINE = form_data.stt.ENGINE + request.app.state.config.STT_MODEL = form_data.stt.MODEL + request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL + + if request.app.state.config.STT_ENGINE == "": + request.app.state.faster_whisper_model = set_faster_whisper_model( + form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE + ) + + return { + "tts": { + "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY, + "API_KEY": request.app.state.config.TTS_API_KEY, + "ENGINE": request.app.state.config.TTS_ENGINE, + "MODEL": request.app.state.config.TTS_MODEL, + "VOICE": request.app.state.config.TTS_VOICE, + "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON, + "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION, + "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, + }, + "stt": { + "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY, + "ENGINE": request.app.state.config.STT_ENGINE, + "MODEL": request.app.state.config.STT_MODEL, + "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, + }, + } + + +def 
load_speech_pipeline(): + from transformers import pipeline + from datasets import load_dataset + + if request.app.state.speech_synthesiser is None: + request.app.state.speech_synthesiser = pipeline( + "text-to-speech", "microsoft/speecht5_tts" + ) + + if request.app.state.speech_speaker_embeddings_dataset is None: + request.app.state.speech_speaker_embeddings_dataset = load_dataset( + "Matthijs/cmu-arctic-xvectors", split="validation" + ) + + +@router.post("/speech") +async def speech(request: Request, user=Depends(get_verified_user)): + body = await request.body() + name = hashlib.sha256(body).hexdigest() + + file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3") + file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json") + + # Check if the file already exists in the cache + if file_path.is_file(): + return FileResponse(file_path) + + payload = None + try: + payload = json.loads(body.decode("utf-8")) + except Exception as e: + log.exception(e) + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + if request.app.state.config.TTS_ENGINE == "openai": + payload["model"] = request.app.state.config.TTS_MODEL + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", + data=payload, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, + ) as r: + r.raise_for_status() + + async with aiofiles.open(file_path, "wb") as f: + await f.write(await r.read()) + + async with aiofiles.open(file_body_path, "w") as f: + await f.write(json.dumps(json.loads(body.decode("utf-8")))) + + return FileResponse(file_path) + + except Exception as e: + log.exception(e) + detail = None + + try: + if r.status != 200: + res = await r.json() + if "error" in res: + detail = f"External: {res['error'].get('message', '')}" + except Exception: + detail = f"External: {e}" + + raise HTTPException( + status_code=getattr(r, "status", 500), + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + elif request.app.state.config.TTS_ENGINE == "elevenlabs": + voice_id = payload.get("voice", "") + + if voice_id not in get_available_voices(): + raise HTTPException( + status_code=400, + detail="Invalid voice id", + ) + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}", + json={ + "text": payload["input"], + "model_id": request.app.state.config.TTS_MODEL, + "voice_settings": {"stability": 0.5, "similarity_boost": 0.5}, + }, + headers={ + "Accept": "audio/mpeg", + "Content-Type": "application/json", + "xi-api-key": request.app.state.config.TTS_API_KEY, + }, + ) as r: + r.raise_for_status() + + async with aiofiles.open(file_path, "wb") as f: + await f.write(await r.read()) + + async with aiofiles.open(file_body_path, "w") as f: + await f.write(json.dumps(json.loads(body.decode("utf-8")))) + + return FileResponse(file_path) + + except Exception as e: + log.exception(e) + detail = None + + try: + if r.status != 200: + res = await r.json() + if "error" in res: + detail = f"External: {res['error'].get('message', '')}" + except Exception: + detail = f"External: {e}" + + raise HTTPException( + status_code=getattr(r, "status", 
500), + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + elif request.app.state.config.TTS_ENGINE == "azure": + try: + payload = json.loads(body.decode("utf-8")) + except Exception as e: + log.exception(e) + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + region = request.app.state.config.TTS_AZURE_SPEECH_REGION + language = request.app.state.config.TTS_VOICE + locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1]) + output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT + + try: + data = f""" + {payload["input"]} + """ + async with aiohttp.ClientSession() as session: + async with session.post( + f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1", + headers={ + "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY, + "Content-Type": "application/ssml+xml", + "X-Microsoft-OutputFormat": output_format, + }, + data=data, + ) as r: + r.raise_for_status() + + async with aiofiles.open(file_path, "wb") as f: + await f.write(await r.read()) + + return FileResponse(file_path) + + except Exception as e: + log.exception(e) + detail = None + + try: + if r.status != 200: + res = await r.json() + if "error" in res: + detail = f"External: {res['error'].get('message', '')}" + except Exception: + detail = f"External: {e}" + + raise HTTPException( + status_code=getattr(r, "status", 500), + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + elif request.app.state.config.TTS_ENGINE == "transformers": + payload = None + try: + payload = json.loads(body.decode("utf-8")) + except Exception as e: + log.exception(e) + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + import torch + import soundfile as sf + + load_speech_pipeline() + + embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset + + speaker_index = 6799 + try: + speaker_index = embeddings_dataset["filename"].index( + request.app.state.config.TTS_MODEL + ) + except Exception: + pass + + speaker_embedding = torch.tensor( + embeddings_dataset[speaker_index]["xvector"] + ).unsqueeze(0) + + speech = request.app.state.speech_synthesiser( + payload["input"], + forward_params={"speaker_embeddings": speaker_embedding}, + ) + + sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"]) + with open(file_body_path, "w") as f: + json.dump(json.loads(body.decode("utf-8")), f) + + return FileResponse(file_path) + + +def transcribe(request: Request, file_path): + print("transcribe", file_path) + filename = os.path.basename(file_path) + file_dir = os.path.dirname(file_path) + id = filename.split(".")[0] + + if request.app.state.config.STT_ENGINE == "": + if request.app.state.faster_whisper_model is None: + request.app.state.faster_whisper_model = set_faster_whisper_model( + request.app.state.config.WHISPER_MODEL + ) + + model = request.app.state.faster_whisper_model + segments, info = model.transcribe(file_path, beam_size=5) + log.info( + "Detected language '%s' with probability %f" + % (info.language, info.language_probability) + ) + + transcript = "".join([segment.text for segment in list(segments)]) + data = {"text": transcript.strip()} + + # save the transcript to a json file + transcript_file = f"{file_dir}/{id}.json" + with open(transcript_file, "w") as f: + json.dump(data, f) + + log.debug(data) + return data + elif request.app.state.config.STT_ENGINE == "openai": + if is_mp4_audio(file_path): + os.rename(file_path, file_path.replace(".wav", ".mp4")) + # Convert MP4 audio file to 
WAV format + convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path) + + r = None + try: + r = requests.post( + url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", + headers={ + "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}" + }, + files={"file": (filename, open(file_path, "rb"))}, + data={"model": request.app.state.config.STT_MODEL}, + ) + + r.raise_for_status() + data = r.json() + + # save the transcript to a json file + transcript_file = f"{file_dir}/{id}.json" + with open(transcript_file, "w") as f: + json.dump(data, f) + + return data + except Exception as e: + log.exception(e) + + detail = None + if r is not None: + try: + res = r.json() + if "error" in res: + detail = f"External: {res['error'].get('message', '')}" + except Exception: + detail = f"External: {e}" + + raise Exception(detail if detail else "Open WebUI: Server Connection Error") + + +def compress_audio(file_path): + if os.path.getsize(file_path) > MAX_FILE_SIZE: + file_dir = os.path.dirname(file_path) + audio = AudioSegment.from_file(file_path) + audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio + compressed_path = f"{file_dir}/{id}_compressed.opus" + audio.export(compressed_path, format="opus", bitrate="32k") + log.debug(f"Compressed audio to {compressed_path}") + + if ( + os.path.getsize(compressed_path) > MAX_FILE_SIZE + ): # Still larger than MAX_FILE_SIZE after compression + raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB")) + return compressed_path + else: + return file_path + + +@router.post("/transcriptions") +def transcription( + request: Request, + file: UploadFile = File(...), + user=Depends(get_verified_user), +): + log.info(f"file.content_type: {file.content_type}") + + if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED, + ) + + try: + ext = file.filename.split(".")[-1] + id = uuid.uuid4() + + filename = f"{id}.{ext}" + contents = file.file.read() + + file_dir = f"{CACHE_DIR}/audio/transcriptions" + os.makedirs(file_dir, exist_ok=True) + file_path = f"{file_dir}/{filename}" + + with open(file_path, "wb") as f: + f.write(contents) + + try: + try: + file_path = compress_audio(file_path) + except Exception as e: + log.exception(e) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + data = transcribe(request, file_path) + file_path = file_path.split("/")[-1] + return {**data, "filename": file_path} + except Exception as e: + log.exception(e) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + except Exception as e: + log.exception(e) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +def get_available_models(request: Request) -> list[dict]: + available_models = [] + if request.app.state.config.TTS_ENGINE == "openai": + available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}] + elif request.app.state.config.TTS_ENGINE == "elevenlabs": + try: + response = requests.get( + "https://api.elevenlabs.io/v1/models", + headers={ + "xi-api-key": request.app.state.config.TTS_API_KEY, + "Content-Type": "application/json", + }, + timeout=5, + ) + response.raise_for_status() + models = response.json() + + available_models = [ + {"name": model["name"], "id": model["model_id"]} for model in models + ] + 
except requests.RequestException as e: + log.error(f"Error fetching voices: {str(e)}") + return available_models + + +@router.get("/models") +async def get_models(request: Request, user=Depends(get_verified_user)): + return {"models": get_available_models(request)} + + +def get_available_voices(request) -> dict: + """Returns {voice_id: voice_name} dict""" + available_voices = {} + if request.app.state.config.TTS_ENGINE == "openai": + available_voices = { + "alloy": "alloy", + "echo": "echo", + "fable": "fable", + "onyx": "onyx", + "nova": "nova", + "shimmer": "shimmer", + } + elif request.app.state.config.TTS_ENGINE == "elevenlabs": + try: + available_voices = get_elevenlabs_voices( + api_key=request.app.state.config.TTS_API_KEY + ) + except Exception: + # Avoided @lru_cache with exception + pass + elif request.app.state.config.TTS_ENGINE == "azure": + try: + region = request.app.state.config.TTS_AZURE_SPEECH_REGION + url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list" + headers = { + "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY + } + + response = requests.get(url, headers=headers) + response.raise_for_status() + voices = response.json() + + for voice in voices: + available_voices[voice["ShortName"]] = ( + f"{voice['DisplayName']} ({voice['ShortName']})" + ) + except requests.RequestException as e: + log.error(f"Error fetching voices: {str(e)}") + + return available_voices + + +@lru_cache +def get_elevenlabs_voices(api_key: str) -> dict: + """ + Note, set the following in your .env file to use Elevenlabs: + AUDIO_TTS_ENGINE=elevenlabs + AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key + AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices + AUDIO_TTS_MODEL=eleven_multilingual_v2 + """ + + try: + # TODO: Add retries + response = requests.get( + "https://api.elevenlabs.io/v1/voices", + headers={ + "xi-api-key": api_key, + "Content-Type": "application/json", + }, + ) + response.raise_for_status() + voices_data = response.json() + + voices = {} + for voice in voices_data.get("voices", []): + voices[voice["voice_id"]] = voice["name"] + except requests.RequestException as e: + # Avoid @lru_cache with exception + log.error(f"Error fetching voices: {str(e)}") + raise RuntimeError(f"Error fetching voices: {str(e)}") + + return voices + + +@router.get("/voices") +async def get_voices(request: Request, user=Depends(get_verified_user)): + return { + "voices": [ + {"id": k, "name": v} for k, v in get_available_voices(request).items() + ] + } diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/routers/auths.py similarity index 55% rename from backend/open_webui/apps/webui/routers/auths.py rename to backend/open_webui/routers/auths.py index ef0a0d445..0b1f42edf 100644 --- a/backend/open_webui/apps/webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -2,12 +2,15 @@ import re import uuid import time import datetime +import logging +from aiohttp import ClientSession -from open_webui.apps.webui.models.auths import ( +from open_webui.models.auths import ( AddUserForm, ApiKey, Auths, Token, + LdapForm, SigninForm, SigninResponse, SignupForm, @@ -15,20 +18,26 @@ from open_webui.apps.webui.models.auths import ( UpdateProfileForm, UserResponse, ) -from open_webui.apps.webui.models.users import Users -from open_webui.config import WEBUI_AUTH +from open_webui.models.users import Users + from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.env import ( + WEBUI_AUTH, 
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE, + SRC_LOG_LEVELS, ) from fastapi import APIRouter, Depends, HTTPException, Request, status -from fastapi.responses import Response +from fastapi.responses import RedirectResponse, Response +from open_webui.config import ( + OPENID_PROVIDER_URL, + ENABLE_OAUTH_SIGNUP, +) from pydantic import BaseModel from open_webui.utils.misc import parse_duration, validate_email_format -from open_webui.utils.utils import ( +from open_webui.utils.auth import ( create_api_key, create_token, get_admin_user, @@ -37,10 +46,19 @@ from open_webui.utils.utils import ( get_password_hash, ) from open_webui.utils.webhook import post_webhook -from typing import Optional +from open_webui.utils.access_control import get_permissions + +from typing import Optional, List + +from ssl import CERT_REQUIRED, PROTOCOL_TLS +from ldap3 import Server, Connection, ALL, Tls +from ldap3.utils.conv import escape_filter_chars router = APIRouter() +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + ############################ # GetSessionUser ############################ @@ -48,6 +66,7 @@ router = APIRouter() class SessionUserResponse(Token, UserResponse): expires_at: Optional[int] = None + permissions: Optional[dict] = None @router.get("/", response_model=SessionUserResponse) @@ -80,6 +99,10 @@ async def get_session_user( secure=WEBUI_SESSION_COOKIE_SECURE, ) + user_permissions = get_permissions( + user.id, request.app.state.config.USER_PERMISSIONS + ) + return { "token": token, "token_type": "Bearer", @@ -89,6 +112,7 @@ async def get_session_user( "name": user.name, "role": user.role, "profile_image_url": user.profile_image_url, + "permissions": user_permissions, } @@ -137,6 +161,146 @@ async def update_password( raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) +############################ +# LDAP Authentication +############################ +@router.post("/ldap", response_model=SigninResponse) +async def ldap_auth(request: Request, response: Response, form_data: LdapForm): + ENABLE_LDAP = request.app.state.config.ENABLE_LDAP + LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL + LDAP_SERVER_HOST = request.app.state.config.LDAP_SERVER_HOST + LDAP_SERVER_PORT = request.app.state.config.LDAP_SERVER_PORT + LDAP_ATTRIBUTE_FOR_USERNAME = request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME + LDAP_SEARCH_BASE = request.app.state.config.LDAP_SEARCH_BASE + LDAP_SEARCH_FILTERS = request.app.state.config.LDAP_SEARCH_FILTERS + LDAP_APP_DN = request.app.state.config.LDAP_APP_DN + LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD + LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS + LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE + LDAP_CIPHERS = ( + request.app.state.config.LDAP_CIPHERS + if request.app.state.config.LDAP_CIPHERS + else "ALL" + ) + + if not ENABLE_LDAP: + raise HTTPException(400, detail="LDAP authentication is not enabled") + + try: + tls = Tls( + validate=CERT_REQUIRED, + version=PROTOCOL_TLS, + ca_certs_file=LDAP_CA_CERT_FILE, + ciphers=LDAP_CIPHERS, + ) + except Exception as e: + log.error(f"An error occurred on TLS: {str(e)}") + raise HTTPException(400, detail=str(e)) + + try: + server = Server( + host=LDAP_SERVER_HOST, + port=LDAP_SERVER_PORT, + get_info=ALL, + use_ssl=LDAP_USE_TLS, + tls=tls, + ) + connection_app = Connection( + server, + LDAP_APP_DN, + LDAP_APP_PASSWORD, + auto_bind="NONE", + authentication="SIMPLE", + 
) + if not connection_app.bind(): + raise HTTPException(400, detail="Application account bind failed") + + search_success = connection_app.search( + search_base=LDAP_SEARCH_BASE, + search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})", + attributes=[f"{LDAP_ATTRIBUTE_FOR_USERNAME}", "mail", "cn"], + ) + + if not search_success: + raise HTTPException(400, detail="User not found in the LDAP server") + + entry = connection_app.entries[0] + username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower() + mail = str(entry["mail"]) + cn = str(entry["cn"]) + user_dn = entry.entry_dn + + if username == form_data.user.lower(): + connection_user = Connection( + server, + user_dn, + form_data.password, + auto_bind="NONE", + authentication="SIMPLE", + ) + if not connection_user.bind(): + raise HTTPException(400, f"Authentication failed for {form_data.user}") + + user = Users.get_user_by_email(mail) + if not user: + try: + role = ( + "admin" + if Users.get_num_users() == 0 + else request.app.state.config.DEFAULT_USER_ROLE + ) + + user = Auths.insert_new_auth( + email=mail, password=str(uuid.uuid4()), name=cn, role=role + ) + + if not user: + raise HTTPException( + 500, detail=ERROR_MESSAGES.CREATE_USER_ERROR + ) + + except HTTPException: + raise + except Exception as err: + raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + + user = Auths.authenticate_user_by_trusted_header(mail) + + if user: + token = create_token( + data={"id": user.id}, + expires_delta=parse_duration( + request.app.state.config.JWT_EXPIRES_IN + ), + ) + + # Set the cookie token + response.set_cookie( + key="token", + value=token, + httponly=True, # Ensures the cookie is not accessible via JavaScript + ) + + return { + "token": token, + "token_type": "Bearer", + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + "profile_image_url": user.profile_image_url, + } + else: + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + else: + raise HTTPException( + 400, + f"User {form_data.user} does not match the record. 
Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}", + ) + except Exception as e: + raise HTTPException(400, detail=str(e)) + + ############################ # SignIn ############################ @@ -211,6 +375,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm): secure=WEBUI_SESSION_COOKIE_SECURE, ) + user_permissions = get_permissions( + user.id, request.app.state.config.USER_PERMISSIONS + ) + return { "token": token, "token_type": "Bearer", @@ -220,6 +388,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): "name": user.name, "role": user.role, "profile_image_url": user.profile_image_url, + "permissions": user_permissions, } else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) @@ -260,6 +429,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm): if Users.get_num_users() == 0 else request.app.state.config.DEFAULT_USER_ROLE ) + + if Users.get_num_users() == 0: + # Disable signup after the first user is created + request.app.state.config.ENABLE_SIGNUP = False + hashed = get_password_hash(form_data.password) user = Auths.insert_new_auth( form_data.email.lower(), @@ -307,6 +481,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm): }, ) + user_permissions = get_permissions( + user.id, request.app.state.config.USER_PERMISSIONS + ) + return { "token": token, "token_type": "Bearer", @@ -316,6 +494,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): "name": user.name, "role": user.role, "profile_image_url": user.profile_image_url, + "permissions": user_permissions, } else: raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) @@ -324,8 +503,31 @@ async def signup(request: Request, response: Response, form_data: SignupForm): @router.get("/signout") -async def signout(response: Response): +async def signout(request: Request, response: Response): response.delete_cookie("token") + + if ENABLE_OAUTH_SIGNUP.value: + oauth_id_token = request.cookies.get("oauth_id_token") + if oauth_id_token: + try: + async with ClientSession() as session: + async with session.get(OPENID_PROVIDER_URL.value) as resp: + if resp.status == 200: + openid_data = await resp.json() + logout_url = openid_data.get("end_session_endpoint") + if logout_url: + response.delete_cookie("oauth_id_token") + return RedirectResponse( + url=f"{logout_url}?id_token_hint={oauth_id_token}" + ) + else: + raise HTTPException( + status_code=resp.status, + detail="Failed to fetch OpenID configuration", + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + return {"status": True} @@ -413,6 +615,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)): return { "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS, "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP, + "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY, "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, @@ -423,6 +626,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)): class AdminConfig(BaseModel): SHOW_ADMIN_DETAILS: bool ENABLE_SIGNUP: bool + ENABLE_API_KEY: bool DEFAULT_USER_ROLE: str JWT_EXPIRES_IN: str ENABLE_COMMUNITY_SHARING: bool @@ -435,6 +639,7 @@ async def update_admin_config( ): request.app.state.config.SHOW_ADMIN_DETAILS = 
form_data.SHOW_ADMIN_DETAILS request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP + request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]: request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE @@ -453,6 +658,7 @@ async def update_admin_config( return { "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS, "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP, + "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY, "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, @@ -460,6 +666,105 @@ async def update_admin_config( } +class LdapServerConfig(BaseModel): + label: str + host: str + port: Optional[int] = None + attribute_for_username: str = "uid" + app_dn: str + app_dn_password: str + search_base: str + search_filters: str = "" + use_tls: bool = True + certificate_path: Optional[str] = None + ciphers: Optional[str] = "ALL" + + +@router.get("/admin/config/ldap/server", response_model=LdapServerConfig) +async def get_ldap_server(request: Request, user=Depends(get_admin_user)): + return { + "label": request.app.state.config.LDAP_SERVER_LABEL, + "host": request.app.state.config.LDAP_SERVER_HOST, + "port": request.app.state.config.LDAP_SERVER_PORT, + "attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME, + "app_dn": request.app.state.config.LDAP_APP_DN, + "app_dn_password": request.app.state.config.LDAP_APP_PASSWORD, + "search_base": request.app.state.config.LDAP_SEARCH_BASE, + "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, + "use_tls": request.app.state.config.LDAP_USE_TLS, + "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, + "ciphers": request.app.state.config.LDAP_CIPHERS, + } + + +@router.post("/admin/config/ldap/server") +async def update_ldap_server( + request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user) +): + required_fields = [ + "label", + "host", + "attribute_for_username", + "app_dn", + "app_dn_password", + "search_base", + ] + for key in required_fields: + value = getattr(form_data, key) + if not value: + raise HTTPException(400, detail=f"Required field {key} is empty") + + if form_data.use_tls and not form_data.certificate_path: + raise HTTPException( + 400, detail="TLS is enabled but certificate file path is missing" + ) + + request.app.state.config.LDAP_SERVER_LABEL = form_data.label + request.app.state.config.LDAP_SERVER_HOST = form_data.host + request.app.state.config.LDAP_SERVER_PORT = form_data.port + request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = ( + form_data.attribute_for_username + ) + request.app.state.config.LDAP_APP_DN = form_data.app_dn + request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password + request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base + request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters + request.app.state.config.LDAP_USE_TLS = form_data.use_tls + request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path + request.app.state.config.LDAP_CIPHERS = form_data.ciphers + + return { + "label": request.app.state.config.LDAP_SERVER_LABEL, + "host": request.app.state.config.LDAP_SERVER_HOST, + "port": request.app.state.config.LDAP_SERVER_PORT, + "attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME, + "app_dn": 
request.app.state.config.LDAP_APP_DN, + "app_dn_password": request.app.state.config.LDAP_APP_PASSWORD, + "search_base": request.app.state.config.LDAP_SEARCH_BASE, + "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, + "use_tls": request.app.state.config.LDAP_USE_TLS, + "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, + "ciphers": request.app.state.config.LDAP_CIPHERS, + } + + +@router.get("/admin/config/ldap") +async def get_ldap_config(request: Request, user=Depends(get_admin_user)): + return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP} + + +class LdapConfigForm(BaseModel): + enable_ldap: Optional[bool] = None + + +@router.post("/admin/config/ldap") +async def update_ldap_config( + request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.ENABLE_LDAP = form_data.enable_ldap + return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP} + + ############################ # API Key ############################ @@ -467,9 +772,16 @@ async def update_admin_config( # create api key @router.post("/api_key", response_model=ApiKey) -async def create_api_key_(user=Depends(get_current_user)): +async def generate_api_key(request: Request, user=Depends(get_current_user)): + if not request.app.state.config.ENABLE_API_KEY: + raise HTTPException( + status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED, + ) + api_key = create_api_key() success = Users.update_user_api_key_by_id(user.id, api_key) + if success: return { "api_key": api_key, diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/routers/chats.py similarity index 97% rename from backend/open_webui/apps/webui/routers/chats.py rename to backend/open_webui/routers/chats.py index b149b2eb4..5e0e75e24 100644 --- a/backend/open_webui/apps/webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -2,22 +2,25 @@ import json import logging from typing import Optional -from open_webui.apps.webui.models.chats import ( +from open_webui.models.chats import ( ChatForm, ChatImportForm, ChatResponse, Chats, ChatTitleIdResponse, ) -from open_webui.apps.webui.models.tags import TagModel, Tags -from open_webui.apps.webui.models.folders import Folders +from open_webui.models.tags import TagModel, Tags +from open_webui.models.folders import Folders from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_verified_user + + +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_permission log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -50,9 +53,10 @@ async def get_session_user_chat_list( @router.delete("/", response_model=bool) async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)): - if user.role == "user" and not request.app.state.config.USER_PERMISSIONS.get( - "chat", {} - ).get("deletion", {}): + + if user.role == "user" and not has_permission( + user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS + ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -385,8 +389,8 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified return result else: - if 
not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get( - "deletion", {} + if not has_permission( + user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -603,7 +607,6 @@ async def add_tag_by_id_and_tag_name( detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"), ) - print(tags, tag_id) if tag_id not in tags: Chats.add_chat_tag_by_id_and_user_id_and_tag_name( id, user.id, form_data.name diff --git a/backend/open_webui/apps/webui/routers/configs.py b/backend/open_webui/routers/configs.py similarity index 63% rename from backend/open_webui/apps/webui/routers/configs.py rename to backend/open_webui/routers/configs.py index 1c30b0b3b..ef6c4d8c1 100644 --- a/backend/open_webui/apps/webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -1,10 +1,12 @@ -from open_webui.config import BannerModel from fastapi import APIRouter, Depends, Request from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_verified_user +from typing import Optional +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.config import get_config, save_config +from open_webui.config import BannerModel + router = APIRouter() @@ -34,8 +36,32 @@ async def export_config(user=Depends(get_admin_user)): return get_config() -class SetDefaultModelsForm(BaseModel): - models: str +############################ +# SetDefaultModels +############################ +class ModelsConfigForm(BaseModel): + DEFAULT_MODELS: Optional[str] + MODEL_ORDER_LIST: Optional[list[str]] + + +@router.get("/models", response_model=ModelsConfigForm) +async def get_models_config(request: Request, user=Depends(get_admin_user)): + return { + "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS, + "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST, + } + + +@router.post("/models", response_model=ModelsConfigForm) +async def set_models_config( + request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS + request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST + return { + "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS, + "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST, + } class PromptSuggestion(BaseModel): @@ -47,21 +73,8 @@ class SetDefaultSuggestionsForm(BaseModel): suggestions: list[PromptSuggestion] -############################ -# SetDefaultModels -############################ - - -@router.post("/default/models", response_model=str) -async def set_global_default_models( - request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user) -): - request.app.state.config.DEFAULT_MODELS = form_data.models - return request.app.state.config.DEFAULT_MODELS - - -@router.post("/default/suggestions", response_model=list[PromptSuggestion]) -async def set_global_default_suggestions( +@router.post("/suggestions", response_model=list[PromptSuggestion]) +async def set_default_suggestions( request: Request, form_data: SetDefaultSuggestionsForm, user=Depends(get_admin_user), diff --git a/backend/open_webui/apps/webui/routers/evaluations.py b/backend/open_webui/routers/evaluations.py similarity index 96% rename from backend/open_webui/apps/webui/routers/evaluations.py rename to backend/open_webui/routers/evaluations.py index b9e3bff29..f0c4a6b06 100644 --- a/backend/open_webui/apps/webui/routers/evaluations.py +++ 
b/backend/open_webui/routers/evaluations.py @@ -2,8 +2,8 @@ from typing import Optional from fastapi import APIRouter, Depends, HTTPException, status, Request from pydantic import BaseModel -from open_webui.apps.webui.models.users import Users, UserModel -from open_webui.apps.webui.models.feedbacks import ( +from open_webui.models.users import Users, UserModel +from open_webui.models.feedbacks import ( FeedbackModel, FeedbackResponse, FeedbackForm, @@ -11,7 +11,7 @@ from open_webui.apps.webui.models.feedbacks import ( ) from open_webui.constants import ERROR_MESSAGES -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user router = APIRouter() diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/routers/files.py similarity index 89% rename from backend/open_webui/apps/webui/routers/files.py rename to backend/open_webui/routers/files.py index b8695eb67..fa36a03ea 100644 --- a/backend/open_webui/apps/webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -5,27 +5,28 @@ from pathlib import Path from typing import Optional from pydantic import BaseModel import mimetypes +from urllib.parse import quote from open_webui.storage.provider import Storage -from open_webui.apps.webui.models.files import ( +from open_webui.models.files import ( FileForm, FileModel, FileModelResponse, Files, ) -from open_webui.apps.retrieval.main import process_file, ProcessFileForm +from open_webui.routers.retrieval import process_file, ProcessFileForm from open_webui.config import UPLOAD_DIR from open_webui.env import SRC_LOG_LEVELS from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request from fastapi.responses import FileResponse, StreamingResponse -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -39,7 +40,9 @@ router = APIRouter() @router.post("/", response_model=FileModelResponse) -def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): +def upload_file( + request: Request, file: UploadFile = File(...), user=Depends(get_verified_user) +): log.info(f"file.content_type: {file.content_type}") try: unsanitized_filename = file.filename @@ -56,7 +59,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): FileForm( **{ "id": id, - "filename": filename, + "filename": name, "path": file_path, "meta": { "name": name, @@ -68,7 +71,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): ) try: - process_file(ProcessFileForm(file_id=id)) + process_file(request, ProcessFileForm(file_id=id)) file_item = Files.get_file_by_id(id=id) except Exception as e: log.exception(e) @@ -183,13 +186,15 @@ class ContentForm(BaseModel): @router.post("/{id}/data/content/update") async def update_file_data_content_by_id( - id: str, form_data: ContentForm, user=Depends(get_verified_user) + request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user) ): file = Files.get_file_by_id(id) if file and (file.user_id == user.id or user.role == "admin"): try: - process_file(ProcessFileForm(file_id=id, content=form_data.content)) + process_file( + request, ProcessFileForm(file_id=id, content=form_data.content) + ) 
file = Files.get_file_by_id(id=id) except Exception as e: log.exception(e) @@ -218,11 +223,15 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): # Check if the file already exists in the cache if file_path.is_file(): - print(f"file_path: {file_path}") + # Handle Unicode filenames + filename = file.meta.get("name", file.filename) + encoded_filename = quote(filename) # RFC5987 encoding headers = { - "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"' + "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}" } + return FileResponse(file_path, headers=headers) + else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -279,16 +288,20 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): if file and (file.user_id == user.id or user.role == "admin"): file_path = file.path + + # Handle Unicode filenames + filename = file.meta.get("name", file.filename) + encoded_filename = quote(filename) # RFC5987 encoding + headers = { + "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}" + } + if file_path: file_path = Storage.get_file(file_path) file_path = Path(file_path) # Check if the file already exists in the cache if file_path.is_file(): - print(f"file_path: {file_path}") - headers = { - "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"' - } return FileResponse(file_path, headers=headers) else: raise HTTPException( @@ -307,7 +320,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): return StreamingResponse( generator(), media_type="text/plain", - headers={"Content-Disposition": f"attachment; filename={file_name}"}, + headers=headers, ) else: raise HTTPException( diff --git a/backend/open_webui/apps/webui/routers/folders.py b/backend/open_webui/routers/folders.py similarity index 97% rename from backend/open_webui/apps/webui/routers/folders.py rename to backend/open_webui/routers/folders.py index 36075c357..ca2fbd213 100644 --- a/backend/open_webui/apps/webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -8,12 +8,12 @@ from pydantic import BaseModel import mimetypes -from open_webui.apps.webui.models.folders import ( +from open_webui.models.folders import ( FolderForm, FolderModel, Folders, ) -from open_webui.apps.webui.models.chats import Chats +from open_webui.models.chats import Chats from open_webui.config import UPLOAD_DIR from open_webui.env import SRC_LOG_LEVELS @@ -24,7 +24,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status from fastapi.responses import FileResponse, StreamingResponse -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/open_webui/apps/webui/routers/functions.py b/backend/open_webui/routers/functions.py similarity index 98% rename from backend/open_webui/apps/webui/routers/functions.py rename to backend/open_webui/routers/functions.py index aeaceecfb..7f3305f25 100644 --- a/backend/open_webui/apps/webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -2,17 +2,17 @@ import os from pathlib import Path from typing import Optional -from open_webui.apps.webui.models.functions import ( +from open_webui.models.functions import ( FunctionForm, FunctionModel, FunctionResponse, Functions, ) -from open_webui.apps.webui.utils import 
load_function_module_by_id, replace_imports +from open_webui.utils.plugin import load_function_module_by_id, replace_imports from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user router = APIRouter() diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py new file mode 100644 index 000000000..e8f8994a4 --- /dev/null +++ b/backend/open_webui/routers/groups.py @@ -0,0 +1,120 @@ +import os +from pathlib import Path +from typing import Optional + +from open_webui.models.groups import ( + Groups, + GroupForm, + GroupUpdateForm, + GroupResponse, +) + +from open_webui.config import CACHE_DIR +from open_webui.constants import ERROR_MESSAGES +from fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.utils.auth import get_admin_user, get_verified_user + +router = APIRouter() + +############################ +# GetFunctions +############################ + + +@router.get("/", response_model=list[GroupResponse]) +async def get_groups(user=Depends(get_verified_user)): + if user.role == "admin": + return Groups.get_groups() + else: + return Groups.get_groups_by_member_id(user.id) + + +############################ +# CreateNewGroup +############################ + + +@router.post("/create", response_model=Optional[GroupResponse]) +async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)): + try: + group = Groups.insert_new_group(user.id, form_data) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error creating group"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# GetGroupById +############################ + + +@router.get("/id/{id}", response_model=Optional[GroupResponse]) +async def get_group_by_id(id: str, user=Depends(get_admin_user)): + group = Groups.get_group_by_id(id) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateGroupById +############################ + + +@router.post("/id/{id}/update", response_model=Optional[GroupResponse]) +async def update_group_by_id( + id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user) +): + try: + group = Groups.update_group_by_id(id, form_data) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating group"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# DeleteGroupById +############################ + + +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_group_by_id(id: str, user=Depends(get_admin_user)): + try: + result = Groups.delete_group_by_id(id) + if result: + return result + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error deleting group"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + 
detail=ERROR_MESSAGES.DEFAULT(e), + ) diff --git a/backend/open_webui/apps/images/main.py b/backend/open_webui/routers/images.py similarity index 55% rename from backend/open_webui/apps/images/main.py rename to backend/open_webui/routers/images.py index 56b251ef6..3f51fbdb4 100644 --- a/backend/open_webui/apps/images/main.py +++ b/backend/open_webui/routers/images.py @@ -9,37 +9,24 @@ from pathlib import Path from typing import Optional import requests -from open_webui.apps.images.utils.comfyui import ( + + +from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel + + +from open_webui.config import CACHE_DIR +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS + +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.images.comfyui import ( ComfyUIGenerateImageForm, ComfyUIWorkflow, comfyui_generate_image, ) -from open_webui.config import ( - AUTOMATIC1111_API_AUTH, - AUTOMATIC1111_BASE_URL, - AUTOMATIC1111_CFG_SCALE, - AUTOMATIC1111_SAMPLER, - AUTOMATIC1111_SCHEDULER, - CACHE_DIR, - COMFYUI_BASE_URL, - COMFYUI_WORKFLOW, - COMFYUI_WORKFLOW_NODES, - CORS_ALLOW_ORIGIN, - ENABLE_IMAGE_GENERATION, - IMAGE_GENERATION_ENGINE, - IMAGE_GENERATION_MODEL, - IMAGE_SIZE, - IMAGE_STEPS, - IMAGES_OPENAI_API_BASE_URL, - IMAGES_OPENAI_API_KEY, - AppConfig, -) -from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ENV, SRC_LOG_LEVELS -from fastapi import Depends, FastAPI, HTTPException, Request -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_verified_user + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["IMAGES"]) @@ -47,59 +34,30 @@ log.setLevel(SRC_LOG_LEVELS["IMAGES"]) IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) -app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None) -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.config = AppConfig() - -app.state.config.ENGINE = IMAGE_GENERATION_ENGINE -app.state.config.ENABLED = ENABLE_IMAGE_GENERATION - -app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL -app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY - -app.state.config.MODEL = IMAGE_GENERATION_MODEL - -app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL -app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH -app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE -app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER -app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER -app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL -app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW -app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES - -app.state.config.IMAGE_SIZE = IMAGE_SIZE -app.state.config.IMAGE_STEPS = IMAGE_STEPS +router = APIRouter() -@app.get("/config") +@router.get("/config") async def get_config(request: Request, user=Depends(get_admin_user)): return { - "enabled": app.state.config.ENABLED, - "engine": app.state.config.ENGINE, + "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, + "engine": 
request.app.state.config.IMAGE_GENERATION_ENGINE, "openai": { - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, }, "automatic1111": { - "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, - "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE, - "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER, - "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER, + "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, + "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH, + "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE, + "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER, + "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER, }, "comfyui": { - "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, - "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, - "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, + "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL, + "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, + "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, } @@ -112,7 +70,7 @@ class OpenAIConfigForm(BaseModel): class Automatic1111ConfigForm(BaseModel): AUTOMATIC1111_BASE_URL: str AUTOMATIC1111_API_AUTH: str - AUTOMATIC1111_CFG_SCALE: Optional[str] + AUTOMATIC1111_CFG_SCALE: Optional[str | float | int] AUTOMATIC1111_SAMPLER: Optional[str] AUTOMATIC1111_SCHEDULER: Optional[str] @@ -131,133 +89,156 @@ class ConfigForm(BaseModel): comfyui: ComfyUIConfigForm -@app.post("/config/update") -async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)): - app.state.config.ENGINE = form_data.engine - app.state.config.ENABLED = form_data.enabled +@router.post("/config/update") +async def update_config( + request: Request, form_data: ConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine + request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled - app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL - app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY + request.app.state.config.IMAGES_OPENAI_API_BASE_URL = ( + form_data.openai.OPENAI_API_BASE_URL + ) + request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY - app.state.config.AUTOMATIC1111_BASE_URL = ( + request.app.state.config.AUTOMATIC1111_BASE_URL = ( form_data.automatic1111.AUTOMATIC1111_BASE_URL ) - app.state.config.AUTOMATIC1111_API_AUTH = ( + request.app.state.config.AUTOMATIC1111_API_AUTH = ( form_data.automatic1111.AUTOMATIC1111_API_AUTH ) - app.state.config.AUTOMATIC1111_CFG_SCALE = ( + request.app.state.config.AUTOMATIC1111_CFG_SCALE = ( float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE) if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE else None ) - app.state.config.AUTOMATIC1111_SAMPLER = ( + request.app.state.config.AUTOMATIC1111_SAMPLER = ( form_data.automatic1111.AUTOMATIC1111_SAMPLER if form_data.automatic1111.AUTOMATIC1111_SAMPLER else None ) - app.state.config.AUTOMATIC1111_SCHEDULER = ( + request.app.state.config.AUTOMATIC1111_SCHEDULER = ( 
form_data.automatic1111.AUTOMATIC1111_SCHEDULER if form_data.automatic1111.AUTOMATIC1111_SCHEDULER else None ) - app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/") - app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW - app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES + request.app.state.config.COMFYUI_BASE_URL = ( + form_data.comfyui.COMFYUI_BASE_URL.strip("/") + ) + request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW + request.app.state.config.COMFYUI_WORKFLOW_NODES = ( + form_data.comfyui.COMFYUI_WORKFLOW_NODES + ) return { - "enabled": app.state.config.ENABLED, - "engine": app.state.config.ENGINE, + "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, + "engine": request.app.state.config.IMAGE_GENERATION_ENGINE, "openai": { - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, }, "automatic1111": { - "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, - "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE, - "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER, - "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER, + "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, + "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH, + "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE, + "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER, + "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER, }, "comfyui": { - "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, - "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, - "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, + "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL, + "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, + "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, } -def get_automatic1111_api_auth(): - if app.state.config.AUTOMATIC1111_API_AUTH is None: +def get_automatic1111_api_auth(request: Request): + if request.app.state.config.AUTOMATIC1111_API_AUTH is None: return "" else: - auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8") + auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode( + "utf-8" + ) auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string) auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8") return f"Basic {auth1111_base64_encoded_string}" -@app.get("/config/url/verify") -async def verify_url(user=Depends(get_admin_user)): - if app.state.config.ENGINE == "automatic1111": +@router.get("/config/url/verify") +async def verify_url(request: Request, user=Depends(get_admin_user)): + if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111": try: r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - headers={"authorization": get_automatic1111_api_auth()}, + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + headers={"authorization": get_automatic1111_api_auth(request)}, ) r.raise_for_status() return True except Exception: - 
app.state.config.ENABLED = False + request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) - elif app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": try: - r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") + r = requests.get( + url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" + ) r.raise_for_status() return True except Exception: - app.state.config.ENABLED = False + request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) else: return True -def set_image_model(model: str): +def set_image_model(request: Request, model: str): log.info(f"Setting image model to {model}") - app.state.config.MODEL = model - if app.state.config.ENGINE in ["", "automatic1111"]: + request.app.state.config.IMAGE_GENERATION_MODEL = model + if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]: api_auth = get_automatic1111_api_auth() r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", headers={"authorization": api_auth}, ) options = r.json() if model != options["sd_model_checkpoint"]: options["sd_model_checkpoint"] = model r = requests.post( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options, headers={"authorization": api_auth}, ) - return app.state.config.MODEL + return request.app.state.config.IMAGE_GENERATION_MODEL -def get_image_model(): - if app.state.config.ENGINE == "openai": - return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2" - elif app.state.config.ENGINE == "comfyui": - return app.state.config.MODEL if app.state.config.MODEL else "" - elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "": +def get_image_model(request): + if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": + return ( + request.app.state.config.IMAGE_GENERATION_MODEL + if request.app.state.config.IMAGE_GENERATION_MODEL + else "dall-e-2" + ) + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": + return ( + request.app.state.config.IMAGE_GENERATION_MODEL + if request.app.state.config.IMAGE_GENERATION_MODEL + else "" + ) + elif ( + request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" + or request.app.state.config.IMAGE_GENERATION_ENGINE == "" + ): try: r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", headers={"authorization": get_automatic1111_api_auth()}, ) options = r.json() return options["sd_model_checkpoint"] except Exception as e: - app.state.config.ENABLED = False + request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -267,23 +248,25 @@ class ImageConfigForm(BaseModel): IMAGE_STEPS: int -@app.get("/image/config") -async def get_image_config(user=Depends(get_admin_user)): +@router.get("/image/config") +async def get_image_config(request: Request, user=Depends(get_admin_user)): return { - "MODEL": app.state.config.MODEL, - "IMAGE_SIZE": app.state.config.IMAGE_SIZE, - "IMAGE_STEPS": app.state.config.IMAGE_STEPS, + "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL, + "IMAGE_SIZE": 
request.app.state.config.IMAGE_SIZE, + "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS, } -@app.post("/image/config/update") -async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)): +@router.post("/image/config/update") +async def update_image_config( + request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user) +): - set_image_model(form_data.MODEL) + set_image_model(request, form_data.MODEL) pattern = r"^\d+x\d+$" if re.match(pattern, form_data.IMAGE_SIZE): - app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE + request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE else: raise HTTPException( status_code=400, @@ -291,7 +274,7 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin ) if form_data.IMAGE_STEPS >= 0: - app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS + request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS else: raise HTTPException( status_code=400, @@ -299,29 +282,31 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin ) return { - "MODEL": app.state.config.MODEL, - "IMAGE_SIZE": app.state.config.IMAGE_SIZE, - "IMAGE_STEPS": app.state.config.IMAGE_STEPS, + "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL, + "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE, + "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS, } -@app.get("/models") -def get_models(user=Depends(get_verified_user)): +@router.get("/models") +def get_models(request: Request, user=Depends(get_verified_user)): try: - if app.state.config.ENGINE == "openai": + if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": return [ {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] - elif app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": # TODO - get models from comfyui - r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") + r = requests.get( + url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" + ) info = r.json() - workflow = json.loads(app.state.config.COMFYUI_WORKFLOW) + workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW) model_node_id = None - for node in app.state.config.COMFYUI_WORKFLOW_NODES: + for node in request.app.state.config.COMFYUI_WORKFLOW_NODES: if node["type"] == "model": if node["node_ids"]: model_node_id = node["node_ids"][0] @@ -357,10 +342,11 @@ def get_models(user=Depends(get_verified_user)): ) ) elif ( - app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" + or request.app.state.config.IMAGE_GENERATION_ENGINE == "" ): r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", headers={"authorization": get_automatic1111_api_auth()}, ) models = r.json() @@ -371,7 +357,7 @@ def get_models(user=Depends(get_verified_user)): ) ) except Exception as e: - app.state.config.ENABLED = False + request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -443,30 +429,41 @@ def save_url_image(url): return None -@app.post("/generations") +@router.post("/generations") async def image_generations( + request: Request, form_data: GenerateImageForm, user=Depends(get_verified_user), ): - width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x"))) + width, height = 
tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x"))) r = None try: - if app.state.config.ENGINE == "openai": + if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai": headers = {} - headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" + headers["Authorization"] = ( + f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}" + ) headers["Content-Type"] = "application/json" + if ENABLE_FORWARD_USER_INFO_HEADERS: + headers["X-OpenWebUI-User-Name"] = user.name + headers["X-OpenWebUI-User-Id"] = user.id + headers["X-OpenWebUI-User-Email"] = user.email + headers["X-OpenWebUI-User-Role"] = user.role + data = { "model": ( - app.state.config.MODEL - if app.state.config.MODEL != "" + request.app.state.config.IMAGE_GENERATION_MODEL + if request.app.state.config.IMAGE_GENERATION_MODEL != "" else "dall-e-2" ), "prompt": form_data.prompt, "n": form_data.n, "size": ( - form_data.size if form_data.size else app.state.config.IMAGE_SIZE + form_data.size + if form_data.size + else request.app.state.config.IMAGE_SIZE ), "response_format": "b64_json", } @@ -474,7 +471,7 @@ async def image_generations( # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations", + url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations", json=data, headers=headers, ) @@ -494,7 +491,7 @@ async def image_generations( return images - elif app.state.config.ENGINE == "comfyui": + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": data = { "prompt": form_data.prompt, "width": width, @@ -502,8 +499,8 @@ async def image_generations( "n": form_data.n, } - if app.state.config.IMAGE_STEPS is not None: - data["steps"] = app.state.config.IMAGE_STEPS + if request.app.state.config.IMAGE_STEPS is not None: + data["steps"] = request.app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt @@ -512,18 +509,18 @@ async def image_generations( **{ "workflow": ComfyUIWorkflow( **{ - "workflow": app.state.config.COMFYUI_WORKFLOW, - "nodes": app.state.config.COMFYUI_WORKFLOW_NODES, + "workflow": request.app.state.config.COMFYUI_WORKFLOW, + "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES, } ), **data, } ) res = await comfyui_generate_image( - app.state.config.MODEL, + request.app.state.config.IMAGE_GENERATION_MODEL, form_data, user.id, - app.state.config.COMFYUI_BASE_URL, + request.app.state.config.COMFYUI_BASE_URL, ) log.debug(f"res: {res}") @@ -540,7 +537,8 @@ async def image_generations( log.debug(f"images: {images}") return images elif ( - app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" + or request.app.state.config.IMAGE_GENERATION_ENGINE == "" ): if form_data.model: set_image_model(form_data.model) @@ -552,25 +550,25 @@ async def image_generations( "height": height, } - if app.state.config.IMAGE_STEPS is not None: - data["steps"] = app.state.config.IMAGE_STEPS + if request.app.state.config.IMAGE_STEPS is not None: + data["steps"] = request.app.state.config.IMAGE_STEPS if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt - if app.state.config.AUTOMATIC1111_CFG_SCALE: - data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE + if request.app.state.config.AUTOMATIC1111_CFG_SCALE: + data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE - if 
app.state.config.AUTOMATIC1111_SAMPLER: - data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER + if request.app.state.config.AUTOMATIC1111_SAMPLER: + data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER - if app.state.config.AUTOMATIC1111_SCHEDULER: - data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER + if request.app.state.config.AUTOMATIC1111_SCHEDULER: + data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", json=data, headers={"authorization": get_automatic1111_api_auth()}, ) diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py similarity index 50% rename from backend/open_webui/apps/webui/routers/knowledge.py rename to backend/open_webui/routers/knowledge.py index 1b5381a74..7f9947d7a 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -1,22 +1,25 @@ import json from typing import Optional, Union from pydantic import BaseModel -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, status, Request import logging -from open_webui.apps.webui.models.knowledge import ( +from open_webui.models.knowledge import ( Knowledges, - KnowledgeUpdateForm, KnowledgeForm, KnowledgeResponse, + KnowledgeUserResponse, ) -from open_webui.apps.webui.models.files import Files, FileModel -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT -from open_webui.apps.retrieval.main import process_file, ProcessFileForm +from open_webui.models.files import Files, FileModel +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.routers.retrieval import process_file, ProcessFileForm from open_webui.constants import ERROR_MESSAGES -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access, has_permission + + from open_webui.env import SRC_LOG_LEVELS @@ -26,64 +29,103 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) router = APIRouter() ############################ -# GetKnowledgeItems +# getKnowledgeBases ############################ -@router.get( - "/", response_model=Optional[Union[list[KnowledgeResponse], KnowledgeResponse]] -) -async def get_knowledge_items( - id: Optional[str] = None, user=Depends(get_verified_user) -): - if id: - knowledge = Knowledges.get_knowledge_by_id(id=id) +@router.get("/", response_model=list[KnowledgeUserResponse]) +async def get_knowledge(user=Depends(get_verified_user)): + knowledge_bases = [] - if knowledge: - return knowledge - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) + if user.role == "admin": + knowledge_bases = Knowledges.get_knowledge_bases() else: - knowledge_bases = [] + knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read") - for knowledge in Knowledges.get_knowledge_items(): - - files = [] - if knowledge.data: - files = Files.get_file_metadatas_by_ids( - knowledge.data.get("file_ids", []) - ) - - # Check if all files exist - if len(files) != len(knowledge.data.get("file_ids", [])): - missing_files = list( - set(knowledge.data.get("file_ids", [])) - - 
set([file.id for file in files]) - ) - if missing_files: - data = knowledge.data or {} - file_ids = data.get("file_ids", []) - - for missing_file in missing_files: - file_ids.remove(missing_file) - - data["file_ids"] = file_ids - Knowledges.update_knowledge_by_id( - id=knowledge.id, form_data=KnowledgeUpdateForm(data=data) - ) - - files = Files.get_file_metadatas_by_ids(file_ids) - - knowledge_bases.append( - KnowledgeResponse( - **knowledge.model_dump(), - files=files, - ) + # Get files for each knowledge base + knowledge_with_files = [] + for knowledge_base in knowledge_bases: + files = [] + if knowledge_base.data: + files = Files.get_file_metadatas_by_ids( + knowledge_base.data.get("file_ids", []) ) - return knowledge_bases + + # Check if all files exist + if len(files) != len(knowledge_base.data.get("file_ids", [])): + missing_files = list( + set(knowledge_base.data.get("file_ids", [])) + - set([file.id for file in files]) + ) + if missing_files: + data = knowledge_base.data or {} + file_ids = data.get("file_ids", []) + + for missing_file in missing_files: + file_ids.remove(missing_file) + + data["file_ids"] = file_ids + Knowledges.update_knowledge_data_by_id( + id=knowledge_base.id, data=data + ) + + files = Files.get_file_metadatas_by_ids(file_ids) + + knowledge_with_files.append( + KnowledgeUserResponse( + **knowledge_base.model_dump(), + files=files, + ) + ) + + return knowledge_with_files + + +@router.get("/list", response_model=list[KnowledgeUserResponse]) +async def get_knowledge_list(user=Depends(get_verified_user)): + knowledge_bases = [] + + if user.role == "admin": + knowledge_bases = Knowledges.get_knowledge_bases() + else: + knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "write") + + # Get files for each knowledge base + knowledge_with_files = [] + for knowledge_base in knowledge_bases: + files = [] + if knowledge_base.data: + files = Files.get_file_metadatas_by_ids( + knowledge_base.data.get("file_ids", []) + ) + + # Check if all files exist + if len(files) != len(knowledge_base.data.get("file_ids", [])): + missing_files = list( + set(knowledge_base.data.get("file_ids", [])) + - set([file.id for file in files]) + ) + if missing_files: + data = knowledge_base.data or {} + file_ids = data.get("file_ids", []) + + for missing_file in missing_files: + file_ids.remove(missing_file) + + data["file_ids"] = file_ids + Knowledges.update_knowledge_data_by_id( + id=knowledge_base.id, data=data + ) + + files = Files.get_file_metadatas_by_ids(file_ids) + + knowledge_with_files.append( + KnowledgeUserResponse( + **knowledge_base.model_dump(), + files=files, + ) + ) + return knowledge_with_files ############################ @@ -92,7 +134,17 @@ async def get_knowledge_items( @router.post("/create", response_model=Optional[KnowledgeResponse]) -async def create_new_knowledge(form_data: KnowledgeForm, user=Depends(get_admin_user)): +async def create_new_knowledge( + request: Request, form_data: KnowledgeForm, user=Depends(get_verified_user) +): + if user.role != "admin" and not has_permission( + user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + knowledge = Knowledges.insert_new_knowledge(user.id, form_data) if knowledge: @@ -118,13 +170,20 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): knowledge = Knowledges.get_knowledge_by_id(id=id) if knowledge: - file_ids = knowledge.data.get("file_ids", []) if 
knowledge.data else [] - files = Files.get_files_by_ids(file_ids) - return KnowledgeFilesResponse( - **knowledge.model_dump(), - files=files, - ) + if ( + user.role == "admin" + or knowledge.user_id == user.id + or has_access(user.id, "read", knowledge.access_control) + ): + + file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] + files = Files.get_files_by_ids(file_ids) + + return KnowledgeFilesResponse( + **knowledge.model_dump(), + files=files, + ) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -140,11 +199,23 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse]) async def update_knowledge_by_id( id: str, - form_data: KnowledgeUpdateForm, - user=Depends(get_admin_user), + form_data: KnowledgeForm, + user=Depends(get_verified_user), ): - knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) + knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) if knowledge: file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] files = Files.get_files_by_ids(file_ids) @@ -171,11 +242,25 @@ class KnowledgeFileIdForm(BaseModel): @router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse]) def add_file_to_knowledge_by_id( + request: Request, id: str, form_data: KnowledgeFileIdForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): knowledge = Knowledges.get_knowledge_by_id(id=id) + + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + file = Files.get_file_by_id(form_data.file_id) if not file: raise HTTPException( @@ -190,7 +275,9 @@ def add_file_to_knowledge_by_id( # Add content to the vector database try: - process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id)) + process_file( + request, ProcessFileForm(file_id=form_data.file_id, collection_name=id) + ) except Exception as e: log.debug(e) raise HTTPException( @@ -206,9 +293,7 @@ def add_file_to_knowledge_by_id( file_ids.append(form_data.file_id) data["file_ids"] = file_ids - knowledge = Knowledges.update_knowledge_by_id( - id=id, form_data=KnowledgeUpdateForm(data=data) - ) + knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data) if knowledge: files = Files.get_files_by_ids(file_ids) @@ -236,11 +321,24 @@ def add_file_to_knowledge_by_id( @router.post("/{id}/file/update", response_model=Optional[KnowledgeFilesResponse]) def update_file_from_knowledge_by_id( + request: Request, id: str, form_data: KnowledgeFileIdForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + 
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + file = Files.get_file_by_id(form_data.file_id) if not file: raise HTTPException( @@ -255,7 +353,9 @@ def update_file_from_knowledge_by_id( # Add content to the vector database try: - process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id)) + process_file( + request, ProcessFileForm(file_id=form_data.file_id, collection_name=id) + ) except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -288,9 +388,21 @@ def update_file_from_knowledge_by_id( def remove_file_from_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + file = Files.get_file_by_id(form_data.file_id) if not file: raise HTTPException( @@ -318,9 +430,7 @@ def remove_file_from_knowledge_by_id( file_ids.remove(form_data.file_id) data["file_ids"] = file_ids - knowledge = Knowledges.update_knowledge_by_id( - id=id, form_data=KnowledgeUpdateForm(data=data) - ) + knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data) if knowledge: files = Files.get_files_by_ids(file_ids) @@ -346,32 +456,26 @@ def remove_file_from_knowledge_by_id( ) -############################ -# ResetKnowledgeById -############################ - - -@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse]) -async def reset_knowledge_by_id(id: str, user=Depends(get_admin_user)): - try: - VECTOR_DB_CLIENT.delete_collection(collection_name=id) - except Exception as e: - log.debug(e) - pass - - knowledge = Knowledges.update_knowledge_by_id( - id=id, form_data=KnowledgeUpdateForm(data={"file_ids": []}) - ) - return knowledge - - ############################ # DeleteKnowledgeById ############################ @router.delete("/{id}/delete", response_model=bool) -async def delete_knowledge_by_id(id: str, user=Depends(get_admin_user)): +async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): + knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + try: VECTOR_DB_CLIENT.delete_collection(collection_name=id) except Exception as e: @@ -379,3 +483,34 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_admin_user)): pass result = Knowledges.delete_knowledge_by_id(id=id) return result + + +############################ +# ResetKnowledgeById +############################ + + +@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse]) +async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): + knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + try: + 
VECTOR_DB_CLIENT.delete_collection(collection_name=id) + except Exception as e: + log.debug(e) + pass + + knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []}) + + return knowledge diff --git a/backend/open_webui/apps/webui/routers/memories.py b/backend/open_webui/routers/memories.py similarity index 96% rename from backend/open_webui/apps/webui/routers/memories.py rename to backend/open_webui/routers/memories.py index ccf84a9d4..e72cf1445 100644 --- a/backend/open_webui/apps/webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -3,9 +3,9 @@ from pydantic import BaseModel import logging from typing import Optional -from open_webui.apps.webui.models.memories import Memories, MemoryModel -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT -from open_webui.utils.utils import get_verified_user +from open_webui.models.memories import Memories, MemoryModel +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.utils.auth import get_verified_user from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py new file mode 100644 index 000000000..db981a913 --- /dev/null +++ b/backend/open_webui/routers/models.py @@ -0,0 +1,189 @@ +from typing import Optional + +from open_webui.models.models import ( + ModelForm, + ModelModel, + ModelResponse, + ModelUserResponse, + Models, +) +from open_webui.constants import ERROR_MESSAGES +from fastapi import APIRouter, Depends, HTTPException, Request, status + + +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access, has_permission + + +router = APIRouter() + + +########################### +# GetModels +########################### + + +@router.get("/", response_model=list[ModelUserResponse]) +async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): + if user.role == "admin": + return Models.get_models() + else: + return Models.get_models_by_user_id(user.id) + + +########################### +# GetBaseModels +########################### + + +@router.get("/base", response_model=list[ModelResponse]) +async def get_base_models(user=Depends(get_admin_user)): + return Models.get_base_models() + + +############################ +# CreateNewModel +############################ + + +@router.post("/create", response_model=Optional[ModelModel]) +async def create_new_model( + request: Request, + form_data: ModelForm, + user=Depends(get_verified_user), +): + if user.role != "admin" and not has_permission( + user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + model = Models.get_model_by_id(form_data.id) + if model: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.MODEL_ID_TAKEN, + ) + + else: + model = Models.insert_new_model(form_data, user.id) + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + +########################### +# GetModelById +########################### + + +# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id +@router.get("/model", response_model=Optional[ModelResponse]) +async def get_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + 
if model: + if ( + user.role == "admin" + or model.user_id == user.id + or has_access(user.id, "read", model.access_control) + ): + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# ToggelModelById +############################ + + +@router.post("/model/toggle", response_model=Optional[ModelResponse]) +async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + if model: + if ( + user.role == "admin" + or model.user_id == user.id + or has_access(user.id, "write", model.access_control) + ): + model = Models.toggle_model_by_id(id) + + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating function"), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateModelById +############################ + + +@router.post("/model/update", response_model=Optional[ModelModel]) +async def update_model_by_id( + id: str, + form_data: ModelForm, + user=Depends(get_verified_user), +): + model = Models.get_model_by_id(id) + + if not model: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + model = Models.update_model_by_id(id, form_data) + return model + + +############################ +# DeleteModelById +############################ + + +@router.delete("/model/delete", response_model=bool) +async def delete_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + if not model: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if model.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + result = Models.delete_model_by_id(id) + return result + + +@router.delete("/delete/all", response_model=bool) +async def delete_all_models(user=Depends(get_admin_user)): + result = Models.delete_all_models() + return result diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py new file mode 100644 index 000000000..233e30ce5 --- /dev/null +++ b/backend/open_webui/routers/ollama.py @@ -0,0 +1,1440 @@ +# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. +# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, +# least connections, or least response time for better resource utilization and performance optimization. 
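The TODO above only names the candidate strategies; as a rough, non-authoritative sketch (not part of this diff), a weighted round-robin pick over OLLAMA_BASE_URLS could look like the following, where build_weighted_cycle, the per-URL weight mapping, and the example host URL are assumed names for illustration only:

import itertools
import random


def build_weighted_cycle(urls, weights):
    # Repeat each URL according to its weight (default 1), shuffle once so the
    # rotation does not always start in config order, then cycle forever.
    expanded = [url for url in urls for _ in range(max(1, weights.get(url, 1)))]
    random.shuffle(expanded)
    return itertools.cycle(expanded)


# Usage sketch: pick the next backend with next(cycle) instead of random.choice(urls).
# cycle = build_weighted_cycle(OLLAMA_BASE_URLS, {"http://gpu-box:11434": 3})
# url = next(cycle)

A least-connections or least-response-time variant would instead keep an in-flight counter or a moving latency average per URL and pick the minimum; the cycle above is just the simplest stateful step up from a purely random choice.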
+ +import asyncio +import json +import logging +import os +import random +import re +import time +from typing import Optional, Union +from urllib.parse import urlparse + +import aiohttp +from aiocache import cached + +import requests + +from fastapi import ( + Depends, + FastAPI, + File, + HTTPException, + Request, + UploadFile, + APIRouter, +) +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, ConfigDict +from starlette.background import BackgroundTask + + +from open_webui.models.models import Models +from open_webui.utils.misc import ( + calculate_sha256, +) +from open_webui.utils.payload import ( + apply_model_params_to_body_ollama, + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, +) +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access + + +from open_webui.config import ( + UPLOAD_DIR, +) +from open_webui.env import ( + ENV, + SRC_LOG_LEVELS, + AIOHTTP_CLIENT_TIMEOUT, + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, + BYPASS_MODEL_ACCESS_CONTROL, +) +from open_webui.constants import ERROR_MESSAGES + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) + + +########################################## +# +# Utility functions +# +########################################## + + +async def send_get_request(url, key=None): + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + try: + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: + async with session.get( + url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + ) as response: + return await response.json() + except Exception as e: + # Handle connection error here + log.error(f"Connection error: {e}") + return None + + +async def send_post_request( + url: str, + payload: Union[str, bytes], + stream: bool = True, + key: Optional[str] = None, + content_type: Optional[str] = None, +): + async def cleanup_response( + response: Optional[aiohttp.ClientResponse], + session: Optional[aiohttp.ClientSession], + ): + if response: + response.close() + if session: + await session.close() + + r = None + try: + session = aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) + + r = await session.post( + url, + data=payload, + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + ) + r.raise_for_status() + + if stream: + response_headers = dict(r.headers) + + if content_type: + response_headers["Content-Type"] = content_type + + return StreamingResponse( + r.content, + status_code=r.status, + headers=response_headers, + background=BackgroundTask( + cleanup_response, response=r, session=session + ), + ) + else: + res = await r.json() + await cleanup_response(r, session) + return res + + except Exception as e: + detail = None + + if r is not None: + try: + res = await r.json() + if "error" in res: + detail = f"Ollama: {res.get('error', 'Unknown error')}" + except Exception: + detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + +def get_api_key(url, configs): + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + return configs.get(base_url, {}).get("key", None) + + +########################################## +# +# API routes +# 
+########################################## + +router = APIRouter() + + +@router.head("/") +@router.get("/") +async def get_status(): + return {"status": True} + + +class ConnectionVerificationForm(BaseModel): + url: str + key: Optional[str] = None + + +@router.post("/verify") +async def verify_connection( + form_data: ConnectionVerificationForm, user=Depends(get_admin_user) +): + url = form_data.url + key = form_data.key + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + ) as session: + try: + async with session.get( + f"{url}/api/version", + headers={**({"Authorization": f"Bearer {key}"} if key else {})}, + ) as r: + if r.status != 200: + detail = f"HTTP Error: {r.status}" + res = await r.json() + + if "error" in res: + detail = f"External Error: {res['error']}" + raise Exception(detail) + + data = await r.json() + return data + except aiohttp.ClientError as e: + log.exception(f"Client error: {str(e)}") + raise HTTPException( + status_code=500, detail="Open WebUI: Server Connection Error" + ) + except Exception as e: + log.exception(f"Unexpected error: {e}") + error_detail = f"Unexpected error: {str(e)}" + raise HTTPException(status_code=500, detail=error_detail) + + +@router.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): + return { + "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API, + "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS, + "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS, + } + + +class OllamaConfigForm(BaseModel): + ENABLE_OLLAMA_API: Optional[bool] = None + OLLAMA_BASE_URLS: list[str] + OLLAMA_API_CONFIGS: dict + + +@router.post("/config/update") +async def update_config( + request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API + + request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS + request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS + + # Remove any extra configs + config_urls = request.app.state.config.OLLAMA_API_CONFIGS.keys() + for url in list(request.app.state.config.OLLAMA_BASE_URLS): + if url not in config_urls: + request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None) + + return { + "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API, + "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS, + "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS, + } + + +@cached(ttl=3) +async def get_all_models(request: Request): + log.info("get_all_models()") + if request.app.state.config.ENABLE_OLLAMA_API: + request_tasks = [] + + for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS): + if url not in request.app.state.config.OLLAMA_API_CONFIGS: + request_tasks.append(send_get_request(f"{url}/api/tags")) + else: + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + enable = api_config.get("enable", True) + key = api_config.get("key", None) + + if enable: + request_tasks.append(send_get_request(f"{url}/api/tags", key)) + else: + request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) + + responses = await asyncio.gather(*request_tasks) + + for idx, response in enumerate(responses): + if response: + url = request.app.state.config.OLLAMA_BASE_URLS[idx] + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + + prefix_id = api_config.get("prefix_id", None) + model_ids = 
api_config.get("model_ids", [])
+
+                if len(model_ids) != 0 and "models" in response:
+                    response["models"] = list(
+                        filter(
+                            lambda model: model["model"] in model_ids,
+                            response["models"],
+                        )
+                    )
+
+                if prefix_id:
+                    for model in response.get("models", []):
+                        model["model"] = f"{prefix_id}.{model['model']}"
+
+        def merge_models_lists(model_lists):
+            merged_models = {}
+
+            for idx, model_list in enumerate(model_lists):
+                if model_list is not None:
+                    for model in model_list:
+                        id = model["model"]
+                        if id not in merged_models:
+                            model["urls"] = [idx]
+                            merged_models[id] = model
+                        else:
+                            merged_models[id]["urls"].append(idx)
+
+            return list(merged_models.values())
+
+        models = {
+            "models": merge_models_lists(
+                map(
+                    lambda response: response.get("models", []) if response else None,
+                    responses,
+                )
+            )
+        }
+
+    else:
+        models = {"models": []}
+
+    request.app.state.OLLAMA_MODELS = {
+        model["model"]: model for model in models["models"]
+    }
+    return models
+
+
+async def get_filtered_models(models, user):
+    # Filter models based on user access control
+    filtered_models = []
+    for model in models.get("models", []):
+        model_info = Models.get_model_by_id(model["model"])
+        if model_info:
+            if user.id == model_info.user_id or has_access(
+                user.id, type="read", access_control=model_info.access_control
+            ):
+                filtered_models.append(model)
+    return filtered_models
+
+
+@router.get("/api/tags")
+@router.get("/api/tags/{url_idx}")
+async def get_ollama_tags(
+    request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
+):
+    models = []
+
+    if url_idx is None:
+        models = await get_all_models(request)
+    else:
+        url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+        key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
+
+        r = None
+        try:
+            r = requests.request(
+                method="GET",
+                url=f"{url}/api/tags",
+                headers={**({"Authorization": f"Bearer {key}"} if key else {})},
+            )
+            r.raise_for_status()
+
+            models = r.json()
+        except Exception as e:
+            log.exception(e)
+
+            detail = None
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        detail = f"Ollama: {res['error']}"
+                except Exception:
+                    detail = f"Ollama: {e}"
+
+            raise HTTPException(
+                status_code=r.status_code if r else 500,
+                detail=detail if detail else "Open WebUI: Server Connection Error",
+            )
+
+    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
+        models["models"] = await get_filtered_models(models, user)
+
+    return models
+
+
+@router.get("/api/version")
+@router.get("/api/version/{url_idx}")
+async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
+    if request.app.state.config.ENABLE_OLLAMA_API:
+        if url_idx is None:
+            # Return the lowest version across all configured Ollama instances
+            request_tasks = [
+                send_get_request(
+                    f"{url}/api/version",
+                    request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
+                        "key", None
+                    ),
+                )
+                for url in request.app.state.config.OLLAMA_BASE_URLS
+            ]
+            responses = await asyncio.gather(*request_tasks)
+            responses = list(filter(lambda x: x is not None, responses))
+
+            if len(responses) > 0:
+                lowest_version = min(
+                    responses,
+                    key=lambda x: tuple(
+                        map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
+                    ),
+                )
+
+                return {"version": lowest_version["version"]}
+            else:
+                raise HTTPException(
+                    status_code=500,
+                    detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
+                )
+        else:
+            url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+
+            r = None
+            try:
+                r = requests.request(method="GET", url=f"{url}/api/version")
+                r.raise_for_status()
+
+                return r.json()
+            except Exception 
as e: + log.exception(e) + + detail = None + if r is not None: + try: + res = r.json() + if "error" in res: + detail = f"Ollama: {res['error']}" + except Exception: + detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + else: + return {"version": False} + + +@router.get("/api/ps") +async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)): + """ + List models that are currently loaded into Ollama memory, and which node they are loaded on. + """ + if request.app.state.config.ENABLE_OLLAMA_API: + request_tasks = [ + send_get_request( + f"{url}/api/ps", + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( + "key", None + ), + ) + for url in request.app.state.config.OLLAMA_BASE_URLS + ] + responses = await asyncio.gather(*request_tasks) + + return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses)) + else: + return {} + + +class ModelNameForm(BaseModel): + name: str + + +@router.post("/api/pull") +@router.post("/api/pull/{url_idx}") +async def pull_model( + request: Request, + form_data: ModelNameForm, + url_idx: int = 0, + user=Depends(get_admin_user), +): + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + log.info(f"url: {url}") + + # Admin should be able to pull models from any source + payload = {**form_data.model_dump(exclude_none=True), "insecure": True} + + return await send_post_request( + url=f"{url}/api/pull", + payload=json.dumps(payload), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + ) + + +class PushModelForm(BaseModel): + name: str + insecure: Optional[bool] = None + stream: Optional[bool] = None + + +@router.delete("/api/push") +@router.delete("/api/push/{url_idx}") +async def push_model( + request: Request, + form_data: PushModelForm, + url_idx: Optional[int] = None, + user=Depends(get_admin_user), +): + if url_idx is None: + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS + + if form_data.name in models: + url_idx = models[form_data.name]["urls"][0] + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), + ) + + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + log.debug(f"url: {url}") + + return await send_post_request( + url=f"{url}/api/push", + payload=form_data.model_dump_json(exclude_none=True).encode(), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + ) + + +class CreateModelForm(BaseModel): + name: str + modelfile: Optional[str] = None + stream: Optional[bool] = None + path: Optional[str] = None + + +@router.post("/api/create") +@router.post("/api/create/{url_idx}") +async def create_model( + request: Request, + form_data: CreateModelForm, + url_idx: int = 0, + user=Depends(get_admin_user), +): + log.debug(f"form_data: {form_data}") + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + + return await send_post_request( + url=f"{url}/api/create", + payload=form_data.model_dump_json(exclude_none=True).encode(), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + ) + + +class CopyModelForm(BaseModel): + source: str + destination: str + + +@router.post("/api/copy") +@router.post("/api/copy/{url_idx}") +async def copy_model( + request: Request, + form_data: CopyModelForm, + url_idx: Optional[int] = None, + user=Depends(get_admin_user), +): + if url_idx is None: + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS + + if 
form_data.source in models: + url_idx = models[form_data.source]["urls"][0] + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), + ) + + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + + try: + r = requests.request( + method="POST", + url=f"{url}/api/copy", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + r.raise_for_status() + + log.debug(f"r.text: {r.text}") + return True + except Exception as e: + log.exception(e) + + detail = None + if r is not None: + try: + res = r.json() + if "error" in res: + detail = f"Ollama: {res['error']}" + except Exception: + detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + +@router.delete("/api/delete") +@router.delete("/api/delete/{url_idx}") +async def delete_model( + request: Request, + form_data: ModelNameForm, + url_idx: Optional[int] = None, + user=Depends(get_admin_user), +): + if url_idx is None: + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS + + if form_data.name in models: + url_idx = models[form_data.name]["urls"][0] + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), + ) + + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + + try: + r = requests.request( + method="DELETE", + url=f"{url}/api/delete", + data=form_data.model_dump_json(exclude_none=True).encode(), + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + ) + r.raise_for_status() + + log.debug(f"r.text: {r.text}") + return True + except Exception as e: + log.exception(e) + + detail = None + if r is not None: + try: + res = r.json() + if "error" in res: + detail = f"Ollama: {res['error']}" + except Exception: + detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + +@router.post("/api/show") +async def show_model_info( + request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) +): + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS + + if form_data.name not in models: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), + ) + + url_idx = random.choice(models[form_data.name]["urls"]) + + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + + try: + r = requests.request( + method="POST", + url=f"{url}/api/show", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + r.raise_for_status() + + return r.json() + except Exception as e: + log.exception(e) + + detail = None + if r is not None: + try: + res = r.json() + if "error" in res: + detail = f"Ollama: {res['error']}" + except Exception: + detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + +class GenerateEmbedForm(BaseModel): + 
model: str + input: list[str] | str + truncate: Optional[bool] = None + options: Optional[dict] = None + keep_alive: Optional[Union[int, str]] = None + + +@router.post("/api/embed") +@router.post("/api/embed/{url_idx}") +async def embed( + request: Request, + form_data: GenerateEmbedForm, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), +): + log.info(f"generate_ollama_batch_embeddings {form_data}") + + if url_idx is None: + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS + + model = form_data.model + + if ":" not in model: + model = f"{model}:latest" + + if model in models: + url_idx = random.choice(models[model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + + try: + r = requests.request( + method="POST", + url=f"{url}/api/embed", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + r.raise_for_status() + + data = r.json() + return data + except Exception as e: + log.exception(e) + + detail = None + if r is not None: + try: + res = r.json() + if "error" in res: + detail = f"Ollama: {res['error']}" + except Exception: + detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + +class GenerateEmbeddingsForm(BaseModel): + model: str + prompt: str + options: Optional[dict] = None + keep_alive: Optional[Union[int, str]] = None + + +@router.post("/api/embeddings") +@router.post("/api/embeddings/{url_idx}") +async def embeddings( + request: Request, + form_data: GenerateEmbeddingsForm, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), +): + log.info(f"generate_ollama_embeddings {form_data}") + + if url_idx is None: + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS + + model = form_data.model + + if ":" not in model: + model = f"{model}:latest" + + if model in models: + url_idx = random.choice(models[model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + + try: + r = requests.request( + method="POST", + url=f"{url}/api/embeddings", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + r.raise_for_status() + + data = r.json() + return data + except Exception as e: + log.exception(e) + + detail = None + if r is not None: + try: + res = r.json() + if "error" in res: + detail = f"Ollama: {res['error']}" + except Exception: + detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + +class GenerateCompletionForm(BaseModel): + model: str + prompt: str + suffix: Optional[str] = None + images: Optional[list[str]] = None + format: Optional[str] = None + options: Optional[dict] = None + system: Optional[str] = None + template: Optional[str] = None + context: Optional[list[int]] = None + stream: Optional[bool] = True + raw: Optional[bool] = None + 
keep_alive: Optional[Union[int, str]] = None + + +@router.post("/api/generate") +@router.post("/api/generate/{url_idx}") +async def generate_completion( + request: Request, + form_data: GenerateCompletionForm, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), +): + if url_idx is None: + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS + + model = form_data.model + + if ":" not in model: + model = f"{model}:latest" + + if model in models: + url_idx = random.choice(models[model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + + prefix_id = api_config.get("prefix_id", None) + if prefix_id: + form_data.model = form_data.model.replace(f"{prefix_id}.", "") + + return await send_post_request( + url=f"{url}/api/generate", + payload=form_data.model_dump_json(exclude_none=True).encode(), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + ) + + +class ChatMessage(BaseModel): + role: str + content: str + images: Optional[list[str]] = None + + +class GenerateChatCompletionForm(BaseModel): + model: str + messages: list[ChatMessage] + format: Optional[str] = None + options: Optional[dict] = None + template: Optional[str] = None + stream: Optional[bool] = True + keep_alive: Optional[Union[int, str]] = None + + +async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None): + if url_idx is None: + models = request.app.state.OLLAMA_MODELS + if model not in models: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), + ) + url_idx = random.choice(models[model].get("urls", [])) + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + return url + + +@router.post("/api/chat") +@router.post("/api/chat/{url_idx}") +async def generate_chat_completion( + request: Request, + form_data: dict, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), + bypass_filter: Optional[bool] = False, +): + if BYPASS_MODEL_ACCESS_CONTROL: + bypass_filter = True + + try: + form_data = GenerateChatCompletionForm(**form_data) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=400, + detail=str(e), + ) + + payload = {**form_data.model_dump(exclude_none=True)} + if "metadata" in payload: + del payload["metadata"] + + model_id = payload["model"] + model_info = Models.get_model_by_id(model_id) + + if model_info: + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + params = model_info.params.model_dump() + + if params: + if payload.get("options") is None: + payload["options"] = {} + + payload["options"] = apply_model_params_to_body_ollama( + params, payload["options"] + ) + payload = apply_model_system_prompt_to_body(params, payload, user) + + # Check if user has access to the model + if not bypass_filter and user.role == "user": + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + elif not bypass_filter: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) + + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" + + url = await get_ollama_url(request, payload["model"], url_idx) + api_config = 
request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + + prefix_id = api_config.get("prefix_id", None) + if prefix_id: + payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + + return await send_post_request( + url=f"{url}/api/chat", + payload=json.dumps(payload), + stream=form_data.stream, + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + content_type="application/x-ndjson", + ) + + +# TODO: we should update this part once Ollama supports other types +class OpenAIChatMessageContent(BaseModel): + type: str + model_config = ConfigDict(extra="allow") + + +class OpenAIChatMessage(BaseModel): + role: str + content: Union[str, list[OpenAIChatMessageContent]] + + model_config = ConfigDict(extra="allow") + + +class OpenAIChatCompletionForm(BaseModel): + model: str + messages: list[OpenAIChatMessage] + + model_config = ConfigDict(extra="allow") + + +class OpenAICompletionForm(BaseModel): + model: str + prompt: str + + model_config = ConfigDict(extra="allow") + + +@router.post("/v1/completions") +@router.post("/v1/completions/{url_idx}") +async def generate_openai_completion( + request: Request, + form_data: dict, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), +): + try: + form_data = OpenAICompletionForm(**form_data) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=400, + detail=str(e), + ) + + payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])} + if "metadata" in payload: + del payload["metadata"] + + model_id = form_data.model + if ":" not in model_id: + model_id = f"{model_id}:latest" + + model_info = Models.get_model_by_id(model_id) + if model_info: + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + params = model_info.params.model_dump() + + if params: + payload = apply_model_params_to_body_openai(params, payload) + + # Check if user has access to the model + if user.role == "user": + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + else: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) + + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" + + url = await get_ollama_url(request, payload["model"], url_idx) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + + prefix_id = api_config.get("prefix_id", None) + + if prefix_id: + payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + + return await send_post_request( + url=f"{url}/v1/completions", + payload=json.dumps(payload), + stream=payload.get("stream", False), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + ) + + +@router.post("/v1/chat/completions") +@router.post("/v1/chat/completions/{url_idx}") +async def generate_openai_chat_completion( + request: Request, + form_data: dict, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), +): + try: + completion_form = OpenAIChatCompletionForm(**form_data) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=400, + detail=str(e), + ) + + payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])} + if "metadata" in payload: + del payload["metadata"] + + model_id = completion_form.model + if ":" not in model_id: + model_id = f"{model_id}:latest" + + model_info = Models.get_model_by_id(model_id) + if 
model_info:
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        params = model_info.params.model_dump()
+
+        if params:
+            payload = apply_model_params_to_body_openai(params, payload)
+            payload = apply_model_system_prompt_to_body(params, payload, user)
+
+        # Check if user has access to the model
+        if user.role == "user":
+            if not (
+                user.id == model_info.user_id
+                or has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                )
+            ):
+                raise HTTPException(
+                    status_code=403,
+                    detail="Model not found",
+                )
+    else:
+        if user.role != "admin":
+            raise HTTPException(
+                status_code=403,
+                detail="Model not found",
+            )
+
+    if ":" not in payload["model"]:
+        payload["model"] = f"{payload['model']}:latest"
+
+    url = await get_ollama_url(request, payload["model"], url_idx)
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+
+    prefix_id = api_config.get("prefix_id", None)
+    if prefix_id:
+        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
+
+    return await send_post_request(
+        url=f"{url}/v1/chat/completions",
+        payload=json.dumps(payload),
+        stream=payload.get("stream", False),
+        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
+    )
+
+
+@router.get("/v1/models")
+@router.get("/v1/models/{url_idx}")
+async def get_openai_models(
+    request: Request,
+    url_idx: Optional[int] = None,
+    user=Depends(get_verified_user),
+):
+
+    models = []
+    if url_idx is None:
+        model_list = await get_all_models(request)
+        models = [
+            {
+                "id": model["model"],
+                "object": "model",
+                "created": int(time.time()),
+                "owned_by": "openai",
+            }
+            for model in model_list["models"]
+        ]
+
+    else:
+        url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+        try:
+            r = requests.request(method="GET", url=f"{url}/api/tags")
+            r.raise_for_status()
+
+            model_list = r.json()
+
+            models = [
+                {
+                    "id": model["model"],
+                    "object": "model",
+                    "created": int(time.time()),
+                    "owned_by": "openai",
+                }
+                for model in model_list["models"]
+            ]
+        except Exception as e:
+            log.exception(e)
+            error_detail = "Open WebUI: Server Connection Error"
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        error_detail = f"Ollama: {res['error']}"
+                except Exception:
+                    error_detail = f"Ollama: {e}"
+
+            raise HTTPException(
+                status_code=r.status_code if r else 500,
+                detail=error_detail,
+            )
+
+    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
+        # Filter models based on user access control
+        filtered_models = []
+        for model in models:
+            model_info = Models.get_model_by_id(model["id"])
+            if model_info:
+                if user.id == model_info.user_id or has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                ):
+                    filtered_models.append(model)
+        models = filtered_models
+
+    return {
+        "data": models,
+        "object": "list",
+    }
+
+
+class UrlForm(BaseModel):
+    url: str
+
+
+class UploadBlobForm(BaseModel):
+    filename: str
+
+
+def parse_huggingface_url(hf_url):
+    try:
+        # Parse the URL
+        parsed_url = urlparse(hf_url)
+
+        # Get the path and split it into components
+        path_components = parsed_url.path.split("/")
+
+        # Extract the desired output
+        model_file = path_components[-1]
+
+        return model_file
+    except ValueError:
+        return None
+
+
+async def download_file_stream(
+    ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
+):
+    done = False
+
+    if os.path.exists(file_path):
+        current_size = os.path.getsize(file_path)
+    else:
+        current_size = 0
+
+    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
+
+    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout
+
+    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+        async with session.get(file_url, headers=headers) as response:
+            total_size = int(response.headers.get("content-length", 0)) + current_size
+
+            with open(file_path, "ab+") as file:
+                async for data in response.content.iter_chunked(chunk_size):
+                    current_size += len(data)
+                    file.write(data)
+
+                    done = current_size == total_size
+                    progress = round((current_size / total_size) * 100, 2)
+
+                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
+
+                if done:
+                    file.seek(0)
+                    hashed = calculate_sha256(file)
+                    file.seek(0)
+
+                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
+                    response = requests.post(url, data=file)
+
+                    if response.ok:
+                        res = {
+                            "done": done,
+                            "blob": f"sha256:{hashed}",
+                            "name": file_name,
+                        }
+                        os.remove(file_path)
+
+                        yield f"data: {json.dumps(res)}\n\n"
+                    else:
+                        raise Exception(
+                            "Ollama: Could not create blob. Please try again."
+                        )
+
+
+# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
+@router.post("/models/download")
+@router.post("/models/download/{url_idx}")
+async def download_model(
+    request: Request,
+    form_data: UrlForm,
+    url_idx: Optional[int] = None,
+    user=Depends(get_admin_user),
+):
+    allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
+
+    if not any(form_data.url.startswith(host) for host in allowed_hosts):
+        raise HTTPException(
+            status_code=400,
+            detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
+        )
+
+    if url_idx is None:
+        url_idx = 0
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+
+    file_name = parse_huggingface_url(form_data.url)
+
+    if file_name:
+        file_path = f"{UPLOAD_DIR}/{file_name}"
+
+        return StreamingResponse(
+            download_file_stream(url, form_data.url, file_path, file_name),
+        )
+    else:
+        return None
+
+
+@router.post("/models/upload")
+@router.post("/models/upload/{url_idx}")
+def upload_model(
+    request: Request,
+    file: UploadFile = File(...),
+    url_idx: Optional[int] = None,
+    user=Depends(get_admin_user),
+):
+    if url_idx is None:
+        url_idx = 0
+    ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+
+    file_path = f"{UPLOAD_DIR}/{file.filename}"
+
+    # Save file in chunks
+    with open(file_path, "wb+") as f:
+        for chunk in file.file:
+            f.write(chunk)
+
+    def file_process_stream():
+        nonlocal ollama_url
+        total_size = os.path.getsize(file_path)
+        chunk_size = 1024 * 1024
+        try:
+            with open(file_path, "rb") as f:
+                total = 0
+                done = False
+
+                while not done:
+                    chunk = f.read(chunk_size)
+                    if not chunk:
+                        done = True
+                        continue
+
+                    total += len(chunk)
+                    progress = round((total / total_size) * 100, 2)
+
+                    res = {
+                        "progress": progress,
+                        "total": total_size,
+                        "completed": total,
+                    }
+                    yield f"data: {json.dumps(res)}\n\n"
+
+                if done:
+                    f.seek(0)
+                    hashed = calculate_sha256(f)
+                    f.seek(0)
+
+                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
+                    response = requests.post(url, data=f)
+
+                    if response.ok:
+                        res = {
+                            "done": done,
+                            "blob": f"sha256:{hashed}",
+                            "name": file.filename,
+                        }
+                        os.remove(file_path)
+                        yield f"data: {json.dumps(res)}\n\n"
+                    else:
+                        raise Exception(
+                            "Ollama: Could not create blob. Please try again." 
+ ) + + except Exception as e: + res = {"error": str(e)} + yield f"data: {json.dumps(res)}\n\n" + + return StreamingResponse(file_process_stream(), media_type="text/event-stream") diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py new file mode 100644 index 000000000..f7f78be85 --- /dev/null +++ b/backend/open_webui/routers/openai.py @@ -0,0 +1,772 @@ +import asyncio +import hashlib +import json +import logging +from pathlib import Path +from typing import Literal, Optional, overload + +import aiohttp +from aiocache import cached +import requests + + +from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse, StreamingResponse +from pydantic import BaseModel +from starlette.background import BackgroundTask + +from open_webui.models.models import Models +from open_webui.config import ( + CACHE_DIR, +) +from open_webui.env import ( + AIOHTTP_CLIENT_TIMEOUT, + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, + ENABLE_FORWARD_USER_INFO_HEADERS, + BYPASS_MODEL_ACCESS_CONTROL, +) + +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import ENV, SRC_LOG_LEVELS + + +from open_webui.utils.payload import ( + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, +) + +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["OPENAI"]) + + +########################################## +# +# Utility functions +# +########################################## + + +async def send_get_request(url, key=None): + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + try: + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: + async with session.get( + url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + ) as response: + return await response.json() + except Exception as e: + # Handle connection error here + log.error(f"Connection error: {e}") + return None + + +async def cleanup_response( + response: Optional[aiohttp.ClientResponse], + session: Optional[aiohttp.ClientSession], +): + if response: + response.close() + if session: + await session.close() + + +def openai_o1_handler(payload): + """ + Handle O1 specific parameters + """ + if "max_tokens" in payload: + # Remove "max_tokens" from the payload + payload["max_completion_tokens"] = payload["max_tokens"] + del payload["max_tokens"] + + # Fix: O1 does not support the "system" parameter, Modify "system" to "user" + if payload["messages"][0]["role"] == "system": + payload["messages"][0]["role"] = "user" + + return payload + + +########################################## +# +# API routes +# +########################################## + +router = APIRouter() + + +@router.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): + return { + "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS, + } + + +class OpenAIConfigForm(BaseModel): + ENABLE_OPENAI_API: Optional[bool] = None + OPENAI_API_BASE_URLS: list[str] + OPENAI_API_KEYS: list[str] + OPENAI_API_CONFIGS: dict + + +@router.post("/config/update") +async def update_config( + request: Request, 
form_data: OpenAIConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API + request.app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS + request.app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS + + # Check if API KEYS length is same than API URLS length + if len(request.app.state.config.OPENAI_API_KEYS) != len( + request.app.state.config.OPENAI_API_BASE_URLS + ): + if len(request.app.state.config.OPENAI_API_KEYS) > len( + request.app.state.config.OPENAI_API_BASE_URLS + ): + request.app.state.config.OPENAI_API_KEYS = ( + request.app.state.config.OPENAI_API_KEYS[ + : len(request.app.state.config.OPENAI_API_BASE_URLS) + ] + ) + else: + request.app.state.config.OPENAI_API_KEYS += [""] * ( + len(request.app.state.config.OPENAI_API_BASE_URLS) + - len(request.app.state.config.OPENAI_API_KEYS) + ) + + request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS + + # Remove any extra configs + config_urls = request.app.state.config.OPENAI_API_CONFIGS.keys() + for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS): + if url not in config_urls: + request.app.state.config.OPENAI_API_CONFIGS.pop(url, None) + + return { + "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS, + } + + +@router.post("/audio/speech") +async def speech(request: Request, user=Depends(get_verified_user)): + idx = None + try: + idx = request.app.state.config.OPENAI_API_BASE_URLS.index( + "https://api.openai.com/v1" + ) + + body = await request.body() + name = hashlib.sha256(body).hexdigest() + + SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") + SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) + file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3") + file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json") + + # Check if the file already exists in the cache + if file_path.is_file(): + return FileResponse(file_path) + + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + + r = None + try: + r = requests.post( + url=f"{url}/audio/speech", + data=body, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}", + **( + { + "HTTP-Referer": "https://openwebui.com/", + "X-Title": "Open WebUI", + } + if "openrouter.ai" in url + else {} + ), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, + stream=True, + ) + + r.raise_for_status() + + # Save the streaming content to a file + with open(file_path, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + + with open(file_body_path, "w") as f: + json.dump(json.loads(body.decode("utf-8")), f) + + # Return the saved file + return FileResponse(file_path) + + except Exception as e: + log.exception(e) + + detail = None + if r is not None: + try: + res = r.json() + if "error" in res: + detail = f"External: {res['error']}" + except Exception: + detail = f"External: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + except ValueError: + raise HTTPException(status_code=401, 
detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) + + +async def get_all_models_responses(request: Request) -> list: + if not request.app.state.config.ENABLE_OPENAI_API: + return [] + + # Check if API KEYS length is same than API URLS length + num_urls = len(request.app.state.config.OPENAI_API_BASE_URLS) + num_keys = len(request.app.state.config.OPENAI_API_KEYS) + + if num_keys != num_urls: + # if there are more keys than urls, remove the extra keys + if num_keys > num_urls: + new_keys = request.app.state.config.OPENAI_API_KEYS[:num_urls] + request.app.state.config.OPENAI_API_KEYS = new_keys + # if there are more urls than keys, add empty keys + else: + request.app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys) + + request_tasks = [] + for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS): + if url not in request.app.state.config.OPENAI_API_CONFIGS: + request_tasks.append( + send_get_request( + f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx] + ) + ) + else: + api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {}) + + enable = api_config.get("enable", True) + model_ids = api_config.get("model_ids", []) + + if enable: + if len(model_ids) == 0: + request_tasks.append( + send_get_request( + f"{url}/models", + request.app.state.config.OPENAI_API_KEYS[idx], + ) + ) + else: + model_list = { + "object": "list", + "data": [ + { + "id": model_id, + "name": model_id, + "owned_by": "openai", + "openai": {"id": model_id}, + "urlIdx": idx, + } + for model_id in model_ids + ], + } + + request_tasks.append( + asyncio.ensure_future(asyncio.sleep(0, model_list)) + ) + else: + request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) + + responses = await asyncio.gather(*request_tasks) + + for idx, response in enumerate(responses): + if response: + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {}) + + prefix_id = api_config.get("prefix_id", None) + + if prefix_id: + for model in ( + response if isinstance(response, list) else response.get("data", []) + ): + model["id"] = f"{prefix_id}.{model['id']}" + + log.debug(f"get_all_models:responses() {responses}") + return responses + + +async def get_filtered_models(models, user): + # Filter models based on user access control + filtered_models = [] + for model in models.get("data", []): + model_info = Models.get_model_by_id(model["id"]) + if model_info: + if user.id == model_info.user_id or has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + return filtered_models + + +@cached(ttl=3) +async def get_all_models(request: Request) -> dict[str, list]: + log.info("get_all_models()") + + if not request.app.state.config.ENABLE_OPENAI_API: + return {"data": []} + + responses = await get_all_models_responses(request) + + def extract_data(response): + if response and "data" in response: + return response["data"] + if isinstance(response, list): + return response + return None + + def merge_models_lists(model_lists): + log.debug(f"merge_models_lists {model_lists}") + merged_list = [] + + for idx, models in enumerate(model_lists): + if models is not None and "error" not in models: + merged_list.extend( + [ + { + **model, + "name": model.get("name", model["id"]), + "owned_by": "openai", + "openai": model, + "urlIdx": idx, + } + for model in models + if "api.openai.com" + not in request.app.state.config.OPENAI_API_BASE_URLS[idx] + or not any( + name in model["id"] + for name in 
[ + "babbage", + "dall-e", + "davinci", + "embedding", + "tts", + "whisper", + ] + ) + ] + ) + + return merged_list + + models = {"data": merge_models_lists(map(extract_data, responses))} + log.debug(f"models: {models}") + + request.app.state.OPENAI_MODELS = {model["id"]: model for model in models["data"]} + return models + + +@router.get("/models") +@router.get("/models/{url_idx}") +async def get_models( + request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user) +): + models = { + "data": [], + } + + if url_idx is None: + models = await get_all_models(request) + else: + url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx] + key = request.app.state.config.OPENAI_API_KEYS[url_idx] + + r = None + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST + ) + ) as session: + try: + async with session.get( + f"{url}/models", + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, + ) as r: + if r.status != 200: + # Extract response error details if available + error_detail = f"HTTP Error: {r.status}" + res = await r.json() + if "error" in res: + error_detail = f"External Error: {res['error']}" + raise Exception(error_detail) + + response_data = await r.json() + + # Check if we're calling OpenAI API based on the URL + if "api.openai.com" in url: + # Filter models according to the specified conditions + response_data["data"] = [ + model + for model in response_data.get("data", []) + if not any( + name in model["id"] + for name in [ + "babbage", + "dall-e", + "davinci", + "embedding", + "tts", + "whisper", + ] + ) + ] + + models = response_data + except aiohttp.ClientError as e: + # ClientError covers all aiohttp requests issues + log.exception(f"Client error: {str(e)}") + raise HTTPException( + status_code=500, detail="Open WebUI: Server Connection Error" + ) + except Exception as e: + log.exception(f"Unexpected error: {e}") + error_detail = f"Unexpected error: {str(e)}" + raise HTTPException(status_code=500, detail=error_detail) + + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: + models["data"] = get_filtered_models(models, user) + + return models + + +class ConnectionVerificationForm(BaseModel): + url: str + key: str + + +@router.post("/verify") +async def verify_connection( + form_data: ConnectionVerificationForm, user=Depends(get_admin_user) +): + url = form_data.url + key = form_data.key + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + ) as session: + try: + async with session.get( + f"{url}/models", + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + }, + ) as r: + if r.status != 200: + # Extract response error details if available + error_detail = f"HTTP Error: {r.status}" + res = await r.json() + if "error" in res: + error_detail = f"External Error: {res['error']}" + raise Exception(error_detail) + + response_data = await r.json() + return response_data + + except aiohttp.ClientError as e: + # ClientError covers all aiohttp requests issues + log.exception(f"Client error: {str(e)}") + raise HTTPException( + status_code=500, detail="Open WebUI: Server Connection Error" + ) + except Exception as e: + log.exception(f"Unexpected error: 
{e}") + error_detail = f"Unexpected error: {str(e)}" + raise HTTPException(status_code=500, detail=error_detail) + + +@router.post("/chat/completions") +async def generate_chat_completion( + request: Request, + form_data: dict, + user=Depends(get_verified_user), + bypass_filter: Optional[bool] = False, +): + idx = 0 + payload = {**form_data} + if "metadata" in payload: + del payload["metadata"] + + model_id = form_data.get("model") + model_info = Models.get_model_by_id(model_id) + + # Check model info and override the payload + if model_info: + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + params = model_info.params.model_dump() + payload = apply_model_params_to_body_openai(params, payload) + payload = apply_model_system_prompt_to_body(params, payload, user) + + # Check if user has access to the model + if not bypass_filter and user.role == "user": + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + elif not bypass_filter: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) + + model = request.app.state.OPENAI_MODELS.get(model_id) + if model: + idx = model["urlIdx"] + else: + raise HTTPException( + status_code=404, + detail="Model not found", + ) + + # Get the API config for the model + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + request.app.state.config.OPENAI_API_BASE_URLS[idx], {} + ) + + prefix_id = api_config.get("prefix_id", None) + if prefix_id: + payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + + # Add user info to the payload if the model is a pipeline + if "pipeline" in model and model.get("pipeline"): + payload["user"] = { + "name": user.name, + "id": user.id, + "email": user.email, + "role": user.role, + } + + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] + + # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens" + is_o1 = payload["model"].lower().startswith("o1-") + if is_o1: + payload = openai_o1_handler(payload) + elif "api.openai.com" not in url: + # Remove "max_tokens" from the payload for backward compatibility + if "max_tokens" in payload: + payload["max_completion_tokens"] = payload["max_tokens"] + del payload["max_tokens"] + + # TODO: check if below is needed + # if "max_tokens" in payload and "max_completion_tokens" in payload: + # del payload["max_tokens"] + + # Convert the modified body back to JSON + payload = json.dumps(payload) + + r = None + session = None + streaming = False + response = None + + try: + session = aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) + + r = await session.request( + method="POST", + url=f"{url}/chat/completions", + data=payload, + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "HTTP-Referer": "https://openwebui.com/", + "X-Title": "Open WebUI", + } + if "openrouter.ai" in url + else {} + ), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, + ) + + # Check if response is SSE + if "text/event-stream" in r.headers.get("Content-Type", ""): + streaming = True + return 
StreamingResponse( + r.content, + status_code=r.status, + headers=dict(r.headers), + background=BackgroundTask( + cleanup_response, response=r, session=session + ), + ) + else: + try: + response = await r.json() + except Exception as e: + log.error(e) + response = await r.text() + + r.raise_for_status() + return response + except Exception as e: + log.exception(e) + + detail = None + if isinstance(response, dict): + if "error" in response: + detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}" + elif isinstance(response, str): + detail = response + + raise HTTPException( + status_code=r.status if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + finally: + if not streaming and session: + if r: + r.close() + await session.close() + + +@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +async def proxy(path: str, request: Request, user=Depends(get_verified_user)): + """ + Deprecated: proxy all requests to OpenAI API + """ + + body = await request.body() + + idx = 0 + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] + + r = None + session = None + streaming = False + + try: + session = aiohttp.ClientSession(trust_env=True) + r = await session.request( + method=request.method, + url=f"{url}/{path}", + data=body, + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, + ) + r.raise_for_status() + + # Check if response is SSE + if "text/event-stream" in r.headers.get("Content-Type", ""): + streaming = True + return StreamingResponse( + r.content, + status_code=r.status, + headers=dict(r.headers), + background=BackgroundTask( + cleanup_response, response=r, session=session + ), + ) + else: + response_data = await r.json() + return response_data + + except Exception as e: + log.exception(e) + + detail = None + if r is not None: + try: + res = await r.json() + print(res) + if "error" in res: + detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" + except Exception: + detail = f"External: {e}" + raise HTTPException( + status_code=r.status if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + finally: + if not streaming and session: + if r: + r.close() + await session.close() diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py new file mode 100644 index 000000000..258c10ee6 --- /dev/null +++ b/backend/open_webui/routers/pipelines.py @@ -0,0 +1,496 @@ +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + Request, + UploadFile, + status, + APIRouter, +) +import os +import logging +import shutil +import requests +from pydantic import BaseModel +from starlette.responses import FileResponse +from typing import Optional + +from open_webui.env import SRC_LOG_LEVELS +from open_webui.config import CACHE_DIR +from open_webui.constants import ERROR_MESSAGES + + +from open_webui.routers.openai import get_all_models_responses + +from open_webui.utils.auth import get_admin_user + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +################################## +# +# Pipeline Middleware +# 
+
+##################################
+
+
+def get_sorted_filters(model_id, models):
+    filters = [
+        model
+        for model in models.values()
+        if "pipeline" in model
+        and "type" in model["pipeline"]
+        and model["pipeline"]["type"] == "filter"
+        and (
+            model["pipeline"]["pipelines"] == ["*"]
+            or any(
+                model_id == target_model_id
+                for target_model_id in model["pipeline"]["pipelines"]
+            )
+        )
+    ]
+    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+    return sorted_filters
+
+
+def process_pipeline_inlet_filter(request, payload, user, models):
+    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
+    model_id = payload["model"]
+
+    sorted_filters = get_sorted_filters(model_id, models)
+    model = models[model_id]
+
+    if "pipeline" in model:
+        sorted_filters.append(model)
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+
+            url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key == "":
+                continue
+
+            headers = {"Authorization": f"Bearer {key}"}
+            r = requests.post(
+                f"{url}/{filter['id']}/filter/inlet",
+                headers=headers,
+                json={
+                    "user": user,
+                    "body": payload,
+                },
+            )
+
+            r.raise_for_status()
+            payload = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                res = r.json()
+                if "detail" in res:
+                    raise Exception(r.status_code, res["detail"])
+
+    return payload
+
+
+def process_pipeline_outlet_filter(request, payload, user, models):
+    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
+    model_id = payload["model"]
+
+    sorted_filters = get_sorted_filters(model_id, models)
+    model = models[model_id]
+
+    if "pipeline" in model:
+        sorted_filters = [model] + sorted_filters
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+
+            url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key != "":
+                r = requests.post(
+                    f"{url}/{filter['id']}/filter/outlet",
+                    headers={"Authorization": f"Bearer {key}"},
+                    json={
+                        "user": user,
+                        "body": payload,
+                    },
+                )
+
+                r.raise_for_status()
+                payload = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                detail = None
+                try:
+                    res = r.json()
+                    if "detail" in res:
+                        detail = res["detail"]
+                except Exception:
+                    pass
+
+                if detail:
+                    raise Exception(r.status_code, detail)
+
+    return payload
+
+
+##################################
+#
+# Pipelines Endpoints
+#
+##################################
+
+router = APIRouter()
+
+
+@router.get("/list")
+async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
+    responses = await get_all_models_responses(request)
+    log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
+
+    urlIdxs = [
+        idx
+        for idx, response in enumerate(responses)
+        if response is not None and "pipelines" in response
+    ]
+
+    return {
+        "data": [
+            {
+                "url": request.app.state.config.OPENAI_API_BASE_URLS[urlIdx],
+                "idx": urlIdx,
+            }
+            for urlIdx in urlIdxs
+        ]
+    }
+
+
+@router.post("/upload")
+async def upload_pipeline(
+    request: Request,
+    urlIdx: int = Form(...),
+    file: UploadFile = File(...),
+    user=Depends(get_admin_user),
+):
+    print("upload_pipeline", urlIdx, file.filename)
+    # Check if the uploaded file is a Python file
+    if not (file.filename and 
file.filename.endswith(".py")): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Only Python (.py) files are allowed.", + ) + + upload_folder = f"{CACHE_DIR}/pipelines" + os.makedirs(upload_folder, exist_ok=True) + file_path = os.path.join(upload_folder, file.filename) + + r = None + try: + # Save the uploaded file + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + with open(file_path, "rb") as f: + files = {"file": f} + r = requests.post( + f"{url}/pipelines/upload", + headers={"Authorization": f"Bearer {key}"}, + files=files, + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + status_code = status.HTTP_404_NOT_FOUND + if r is not None: + status_code = r.status_code + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=status_code, + detail=detail if detail else "Pipeline not found", + ) + finally: + # Ensure the file is deleted after the upload is completed or on failure + if os.path.exists(file_path): + os.remove(file_path) + + +class AddPipelineForm(BaseModel): + url: str + urlIdx: int + + +@router.post("/add") +async def add_pipeline( + request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user) +): + r = None + try: + urlIdx = form_data.urlIdx + + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + r = requests.post( + f"{url}/pipelines/add", + headers={"Authorization": f"Bearer {key}"}, + json={"url": form_data.url}, + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else "Pipeline not found", + ) + + +class DeletePipelineForm(BaseModel): + id: str + urlIdx: int + + +@router.delete("/delete") +async def delete_pipeline( + request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user) +): + r = None + try: + urlIdx = form_data.urlIdx + + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + r = requests.delete( + f"{url}/pipelines/delete", + headers={"Authorization": f"Bearer {key}"}, + json={"id": form_data.id}, + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else "Pipeline not found", + ) + + +@router.get("/") +async def get_pipelines( + request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user) +): + r = None + try: + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + r = requests.get(f"{url}/pipelines", 
headers={"Authorization": f"Bearer {key}"}) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else "Pipeline not found", + ) + + +@router.get("/{pipeline_id}/valves") +async def get_pipeline_valves( + request: Request, + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), +): + r = None + try: + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + r = requests.get( + f"{url}/{pipeline_id}/valves", headers={"Authorization": f"Bearer {key}"} + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else "Pipeline not found", + ) + + +@router.get("/{pipeline_id}/valves/spec") +async def get_pipeline_valves_spec( + request: Request, + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), +): + r = None + try: + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + r = requests.get( + f"{url}/{pipeline_id}/valves/spec", + headers={"Authorization": f"Bearer {key}"}, + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else "Pipeline not found", + ) + + +@router.post("/{pipeline_id}/valves/update") +async def update_pipeline_valves( + request: Request, + urlIdx: Optional[int], + pipeline_id: str, + form_data: dict, + user=Depends(get_admin_user), +): + r = None + try: + url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = request.app.state.config.OPENAI_API_KEYS[urlIdx] + + r = requests.post( + f"{url}/{pipeline_id}/valves/update", + headers={"Authorization": f"Bearer {key}"}, + json={**form_data}, + ) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = None + + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except Exception: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail if detail else "Pipeline not found", + ) diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py new file mode 100644 index 000000000..4f1c48482 --- /dev/null +++ b/backend/open_webui/routers/prompts.py @@ -0,0 +1,152 @@ +from typing import Optional + +from open_webui.models.prompts import ( + PromptForm, + PromptUserResponse, + PromptModel, + Prompts, +) +from 
open_webui.constants import ERROR_MESSAGES +from fastapi import APIRouter, Depends, HTTPException, status, Request +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access, has_permission + +router = APIRouter() + +############################ +# GetPrompts +############################ + + +@router.get("/", response_model=list[PromptModel]) +async def get_prompts(user=Depends(get_verified_user)): + if user.role == "admin": + prompts = Prompts.get_prompts() + else: + prompts = Prompts.get_prompts_by_user_id(user.id, "read") + + return prompts + + +@router.get("/list", response_model=list[PromptUserResponse]) +async def get_prompt_list(user=Depends(get_verified_user)): + if user.role == "admin": + prompts = Prompts.get_prompts() + else: + prompts = Prompts.get_prompts_by_user_id(user.id, "write") + + return prompts + + +############################ +# CreateNewPrompt +############################ + + +@router.post("/create", response_model=Optional[PromptModel]) +async def create_new_prompt( + request: Request, form_data: PromptForm, user=Depends(get_verified_user) +): + if user.role != "admin" and not has_permission( + user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + prompt = Prompts.get_prompt_by_command(form_data.command) + if prompt is None: + prompt = Prompts.insert_new_prompt(user.id, form_data) + + if prompt: + return prompt + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(), + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.COMMAND_TAKEN, + ) + + +############################ +# GetPromptByCommand +############################ + + +@router.get("/command/{command}", response_model=Optional[PromptModel]) +async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): + prompt = Prompts.get_prompt_by_command(f"/{command}") + + if prompt: + if ( + user.role == "admin" + or prompt.user_id == user.id + or has_access(user.id, "read", prompt.access_control) + ): + return prompt + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdatePromptByCommand +############################ + + +@router.post("/command/{command}/update", response_model=Optional[PromptModel]) +async def update_prompt_by_command( + command: str, + form_data: PromptForm, + user=Depends(get_verified_user), +): + prompt = Prompts.get_prompt_by_command(f"/{command}") + if not prompt: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if prompt.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) + if prompt: + return prompt + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + +############################ +# DeletePromptByCommand +############################ + + +@router.delete("/command/{command}/delete", response_model=bool) +async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)): + prompt = Prompts.get_prompt_by_command(f"/{command}") + if not prompt: + raise HTTPException( + 
status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if prompt.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + result = Prompts.delete_prompt_by_command(f"/{command}") + return result diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py new file mode 100644 index 000000000..e577f70f1 --- /dev/null +++ b/backend/open_webui/routers/retrieval.py @@ -0,0 +1,1430 @@ +import json +import logging +import mimetypes +import os +import shutil + +import uuid +from datetime import datetime +from pathlib import Path +from typing import Iterator, Optional, Sequence, Union + +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + UploadFile, + Request, + status, + APIRouter, +) +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +import tiktoken + + +from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter +from langchain_core.documents import Document + +from open_webui.models.files import Files +from open_webui.models.knowledge import Knowledges +from open_webui.storage.provider import Storage + + +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT + +# Document loaders +from open_webui.retrieval.loaders.main import Loader +from open_webui.retrieval.loaders.youtube import YoutubeLoader + +# Web search engines +from open_webui.retrieval.web.main import SearchResult +from open_webui.retrieval.web.utils import get_web_loader +from open_webui.retrieval.web.brave import search_brave +from open_webui.retrieval.web.kagi import search_kagi +from open_webui.retrieval.web.mojeek import search_mojeek +from open_webui.retrieval.web.duckduckgo import search_duckduckgo +from open_webui.retrieval.web.google_pse import search_google_pse +from open_webui.retrieval.web.jina_search import search_jina +from open_webui.retrieval.web.searchapi import search_searchapi +from open_webui.retrieval.web.searxng import search_searxng +from open_webui.retrieval.web.serper import search_serper +from open_webui.retrieval.web.serply import search_serply +from open_webui.retrieval.web.serpstack import search_serpstack +from open_webui.retrieval.web.tavily import search_tavily +from open_webui.retrieval.web.bing import search_bing + + +from open_webui.retrieval.utils import ( + get_embedding_function, + get_model_path, + query_collection, + query_collection_with_hybrid_search, + query_doc, + query_doc_with_hybrid_search, +) +from open_webui.utils.misc import ( + calculate_sha256_string, +) +from open_webui.utils.auth import get_admin_user, get_verified_user + + +from open_webui.config import ( + ENV, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, + RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + RAG_RERANKING_MODEL_AUTO_UPDATE, + RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + UPLOAD_DIR, + DEFAULT_LOCALE, +) +from open_webui.env import ( + SRC_LOG_LEVELS, + DEVICE_TYPE, + DOCKER, +) +from open_webui.constants import ERROR_MESSAGES + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + +########################################## +# +# Utility functions +# +########################################## + + +def get_ef( + engine: str, + embedding_model: str, + auto_update: bool = False, +): + ef = None + if embedding_model and engine == "": + from sentence_transformers import SentenceTransformer + + try: + ef = SentenceTransformer( + 
get_model_path(embedding_model, auto_update), + device=DEVICE_TYPE, + trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + ) + except Exception as e: + log.debug(f"Error loading SentenceTransformer: {e}") + + return ef + + +def get_rf( + reranking_model: str, + auto_update: bool = False, +): + rf = None + if reranking_model: + if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]): + try: + from open_webui.retrieval.models.colbert import ColBERT + + rf = ColBERT( + get_model_path(reranking_model, auto_update), + env="docker" if DOCKER else None, + ) + + except Exception as e: + log.error(f"ColBERT: {e}") + raise Exception(ERROR_MESSAGES.DEFAULT(e)) + else: + import sentence_transformers + + try: + rf = sentence_transformers.CrossEncoder( + get_model_path(reranking_model, auto_update), + device=DEVICE_TYPE, + trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + ) + except: + log.error("CrossEncoder error") + raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error")) + return rf + + +########################################## +# +# API routes +# +########################################## + + +router = APIRouter() + + +class CollectionNameForm(BaseModel): + collection_name: Optional[str] = None + + +class ProcessUrlForm(CollectionNameForm): + url: str + + +class SearchForm(CollectionNameForm): + query: str + + +@router.get("/") +async def get_status(request: Request): + return { + "status": True, + "chunk_size": request.app.state.config.CHUNK_SIZE, + "chunk_overlap": request.app.state.config.CHUNK_OVERLAP, + "template": request.app.state.config.RAG_TEMPLATE, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "reranking_model": request.app.state.config.RAG_RERANKING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + } + + +@router.get("/embedding") +async def get_embedding_config(request: Request, user=Depends(get_admin_user)): + return { + "status": True, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "openai_config": { + "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, + "key": request.app.state.config.RAG_OPENAI_API_KEY, + }, + "ollama_config": { + "url": request.app.state.config.RAG_OLLAMA_BASE_URL, + "key": request.app.state.config.RAG_OLLAMA_API_KEY, + }, + } + + +@router.get("/reranking") +async def get_reraanking_config(request: Request, user=Depends(get_admin_user)): + return { + "status": True, + "reranking_model": request.app.state.config.RAG_RERANKING_MODEL, + } + + +class OpenAIConfigForm(BaseModel): + url: str + key: str + + +class OllamaConfigForm(BaseModel): + url: str + key: str + + +class EmbeddingModelUpdateForm(BaseModel): + openai_config: Optional[OpenAIConfigForm] = None + ollama_config: Optional[OllamaConfigForm] = None + embedding_engine: str + embedding_model: str + embedding_batch_size: Optional[int] = 1 + + +@router.post("/embedding/update") +async def update_embedding_config( + request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) +): + log.info( + f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" + ) + try: + request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine + request.app.state.config.RAG_EMBEDDING_MODEL = 
form_data.embedding_model + + if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: + if form_data.openai_config is not None: + request.app.state.config.RAG_OPENAI_API_BASE_URL = ( + form_data.openai_config.url + ) + request.app.state.config.RAG_OPENAI_API_KEY = ( + form_data.openai_config.key + ) + + if form_data.ollama_config is not None: + request.app.state.config.RAG_OLLAMA_BASE_URL = ( + form_data.ollama_config.url + ) + request.app.state.config.RAG_OLLAMA_API_KEY = ( + form_data.ollama_config.key + ) + + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( + form_data.embedding_batch_size + ) + + request.app.state.ef = get_ef( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + ) + + request.app.state.EMBEDDING_FUNCTION = get_embedding_function( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.ef, + ( + request.app.state.config.RAG_OPENAI_API_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.RAG_OLLAMA_BASE_URL + ), + ( + request.app.state.config.RAG_OPENAI_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.RAG_OLLAMA_API_KEY + ), + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + ) + + return { + "status": True, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "openai_config": { + "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, + "key": request.app.state.config.RAG_OPENAI_API_KEY, + }, + "ollama_config": { + "url": request.app.state.config.RAG_OLLAMA_BASE_URL, + "key": request.app.state.config.RAG_OLLAMA_API_KEY, + }, + } + except Exception as e: + log.exception(f"Problem updating embedding model: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +class RerankingModelUpdateForm(BaseModel): + reranking_model: str + + +@router.post("/reranking/update") +async def update_reranking_config( + request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) +): + log.info( + f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}" + ) + try: + request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model + + try: + request.app.state.rf = get_rf( + request.app.state.config.RAG_RERANKING_MODEL, + True, + ) + except Exception as e: + log.error(f"Error loading reranking model: {e}") + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False + + return { + "status": True, + "reranking_model": request.app.state.config.RAG_RERANKING_MODEL, + } + except Exception as e: + log.exception(f"Problem updating reranking model: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +@router.get("/config") +async def get_rag_config(request: Request, user=Depends(get_admin_user)): + return { + "status": True, + "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, + "content_extraction": { + "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, + "tika_server_url": request.app.state.config.TIKA_SERVER_URL, + }, + "chunk": { + "text_splitter": request.app.state.config.TEXT_SPLITTER, + "chunk_size": request.app.state.config.CHUNK_SIZE, + 
"chunk_overlap": request.app.state.config.CHUNK_OVERLAP, + }, + "file": { + "max_size": request.app.state.config.FILE_MAX_SIZE, + "max_count": request.app.state.config.FILE_MAX_COUNT, + }, + "youtube": { + "language": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION, + "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, + }, + "web": { + "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "search": { + "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH, + "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE, + "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL, + "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY, + "google_pse_engine_id": request.app.state.config.GOOGLE_PSE_ENGINE_ID, + "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY, + "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY, + "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY, + "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY, + "serpstack_https": request.app.state.config.SERPSTACK_HTTPS, + "serper_api_key": request.app.state.config.SERPER_API_KEY, + "serply_api_key": request.app.state.config.SERPLY_API_KEY, + "tavily_api_key": request.app.state.config.TAVILY_API_KEY, + "searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY, + "seaarchapi_engine": request.app.state.config.SEARCHAPI_ENGINE, + "jina_api_key": request.app.state.config.JINA_API_KEY, + "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT, + "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + }, + }, + } + + +class FileConfig(BaseModel): + max_size: Optional[int] = None + max_count: Optional[int] = None + + +class ContentExtractionConfig(BaseModel): + engine: str = "" + tika_server_url: Optional[str] = None + + +class ChunkParamUpdateForm(BaseModel): + text_splitter: Optional[str] = None + chunk_size: int + chunk_overlap: int + + +class YoutubeLoaderConfig(BaseModel): + language: list[str] + translation: Optional[str] = None + proxy_url: str = "" + + +class WebSearchConfig(BaseModel): + enabled: bool + engine: Optional[str] = None + searxng_query_url: Optional[str] = None + google_pse_api_key: Optional[str] = None + google_pse_engine_id: Optional[str] = None + brave_search_api_key: Optional[str] = None + kagi_search_api_key: Optional[str] = None + mojeek_search_api_key: Optional[str] = None + serpstack_api_key: Optional[str] = None + serpstack_https: Optional[bool] = None + serper_api_key: Optional[str] = None + serply_api_key: Optional[str] = None + tavily_api_key: Optional[str] = None + searchapi_api_key: Optional[str] = None + searchapi_engine: Optional[str] = None + jina_api_key: Optional[str] = None + bing_search_v7_endpoint: Optional[str] = None + bing_search_v7_subscription_key: Optional[str] = None + result_count: Optional[int] = None + concurrent_requests: Optional[int] = None + + +class WebConfig(BaseModel): + search: WebSearchConfig + web_loader_ssl_verification: Optional[bool] = None + + +class ConfigUpdateForm(BaseModel): + pdf_extract_images: Optional[bool] = None + file: Optional[FileConfig] = None + content_extraction: Optional[ContentExtractionConfig] = None + chunk: 
Optional[ChunkParamUpdateForm] = None + youtube: Optional[YoutubeLoaderConfig] = None + web: Optional[WebConfig] = None + + +@router.post("/config/update") +async def update_rag_config( + request: Request, form_data: ConfigUpdateForm, user=Depends(get_admin_user) +): + request.app.state.config.PDF_EXTRACT_IMAGES = ( + form_data.pdf_extract_images + if form_data.pdf_extract_images is not None + else request.app.state.config.PDF_EXTRACT_IMAGES + ) + + if form_data.file is not None: + request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size + request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count + + if form_data.content_extraction is not None: + log.info(f"Updating text settings: {form_data.content_extraction}") + request.app.state.config.CONTENT_EXTRACTION_ENGINE = ( + form_data.content_extraction.engine + ) + request.app.state.config.TIKA_SERVER_URL = ( + form_data.content_extraction.tika_server_url + ) + + if form_data.chunk is not None: + request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter + request.app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size + request.app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap + + if form_data.youtube is not None: + request.app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language + request.app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url + request.app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation + + if form_data.web is not None: + request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( + # Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False + form_data.web.web_loader_ssl_verification + ) + + request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled + request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine + request.app.state.config.SEARXNG_QUERY_URL = ( + form_data.web.search.searxng_query_url + ) + request.app.state.config.GOOGLE_PSE_API_KEY = ( + form_data.web.search.google_pse_api_key + ) + request.app.state.config.GOOGLE_PSE_ENGINE_ID = ( + form_data.web.search.google_pse_engine_id + ) + request.app.state.config.BRAVE_SEARCH_API_KEY = ( + form_data.web.search.brave_search_api_key + ) + request.app.state.config.KAGI_SEARCH_API_KEY = ( + form_data.web.search.kagi_search_api_key + ) + request.app.state.config.MOJEEK_SEARCH_API_KEY = ( + form_data.web.search.mojeek_search_api_key + ) + request.app.state.config.SERPSTACK_API_KEY = ( + form_data.web.search.serpstack_api_key + ) + request.app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https + request.app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key + request.app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key + request.app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key + request.app.state.config.SEARCHAPI_API_KEY = ( + form_data.web.search.searchapi_api_key + ) + request.app.state.config.SEARCHAPI_ENGINE = ( + form_data.web.search.searchapi_engine + ) + + request.app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key + request.app.state.config.BING_SEARCH_V7_ENDPOINT = ( + form_data.web.search.bing_search_v7_endpoint + ) + request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = ( + form_data.web.search.bing_search_v7_subscription_key + ) + + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = ( + form_data.web.search.result_count + ) + request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( 
+ form_data.web.search.concurrent_requests + ) + + return { + "status": True, + "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, + "file": { + "max_size": request.app.state.config.FILE_MAX_SIZE, + "max_count": request.app.state.config.FILE_MAX_COUNT, + }, + "content_extraction": { + "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, + "tika_server_url": request.app.state.config.TIKA_SERVER_URL, + }, + "chunk": { + "text_splitter": request.app.state.config.TEXT_SPLITTER, + "chunk_size": request.app.state.config.CHUNK_SIZE, + "chunk_overlap": request.app.state.config.CHUNK_OVERLAP, + }, + "youtube": { + "language": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, + "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION, + }, + "web": { + "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "search": { + "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH, + "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE, + "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL, + "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY, + "google_pse_engine_id": request.app.state.config.GOOGLE_PSE_ENGINE_ID, + "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY, + "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY, + "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY, + "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY, + "serpstack_https": request.app.state.config.SERPSTACK_HTTPS, + "serper_api_key": request.app.state.config.SERPER_API_KEY, + "serply_api_key": request.app.state.config.SERPLY_API_KEY, + "serachapi_api_key": request.app.state.config.SEARCHAPI_API_KEY, + "searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE, + "tavily_api_key": request.app.state.config.TAVILY_API_KEY, + "jina_api_key": request.app.state.config.JINA_API_KEY, + "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT, + "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + }, + }, + } + + +@router.get("/template") +async def get_rag_template(request: Request, user=Depends(get_verified_user)): + return { + "status": True, + "template": request.app.state.config.RAG_TEMPLATE, + } + + +@router.get("/query/settings") +async def get_query_settings(request: Request, user=Depends(get_admin_user)): + return { + "status": True, + "template": request.app.state.config.RAG_TEMPLATE, + "k": request.app.state.config.TOP_K, + "r": request.app.state.config.RELEVANCE_THRESHOLD, + "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, + } + + +class QuerySettingsForm(BaseModel): + k: Optional[int] = None + r: Optional[float] = None + template: Optional[str] = None + hybrid: Optional[bool] = None + + +@router.post("/query/settings/update") +async def update_query_settings( + request: Request, form_data: QuerySettingsForm, user=Depends(get_admin_user) +): + request.app.state.config.RAG_TEMPLATE = form_data.template + request.app.state.config.TOP_K = form_data.k if form_data.k else 4 + request.app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 + + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( + form_data.hybrid if form_data.hybrid else False + ) 
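    # Note: the `value if value else default` pattern above treats falsy input
    # the same as missing input, so a request that explicitly sends k=0 is
    # stored as the default 4; for r and hybrid the fallbacks (0.0 and False)
    # happen to coincide with the falsy values, so there is no visible effect.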
+ + return { + "status": True, + "template": request.app.state.config.RAG_TEMPLATE, + "k": request.app.state.config.TOP_K, + "r": request.app.state.config.RELEVANCE_THRESHOLD, + "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, + } + + +#################################### +# +# Document process and retrieval +# +#################################### + + +def save_docs_to_vector_db( + request: Request, + docs, + collection_name, + metadata: Optional[dict] = None, + overwrite: bool = False, + split: bool = True, + add: bool = False, +) -> bool: + def _get_docs_info(docs: list[Document]) -> str: + docs_info = set() + + # Trying to select relevant metadata identifying the document. + for doc in docs: + metadata = getattr(doc, "metadata", {}) + doc_name = metadata.get("name", "") + if not doc_name: + doc_name = metadata.get("title", "") + if not doc_name: + doc_name = metadata.get("source", "") + if doc_name: + docs_info.add(doc_name) + + return ", ".join(docs_info) + + log.info( + f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}" + ) + + # Check if entries with the same hash (metadata.hash) already exist + if metadata and "hash" in metadata: + result = VECTOR_DB_CLIENT.query( + collection_name=collection_name, + filter={"hash": metadata["hash"]}, + ) + + if result is not None: + existing_doc_ids = result.ids[0] + if existing_doc_ids: + log.info(f"Document with hash {metadata['hash']} already exists") + raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) + + if split: + if request.app.state.config.TEXT_SPLITTER in ["", "character"]: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=request.app.state.config.CHUNK_SIZE, + chunk_overlap=request.app.state.config.CHUNK_OVERLAP, + add_start_index=True, + ) + elif request.app.state.config.TEXT_SPLITTER == "token": + log.info( + f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}" + ) + + tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME)) + text_splitter = TokenTextSplitter( + encoding_name=str(request.app.state.config.TIKTOKEN_ENCODING_NAME), + chunk_size=request.app.state.config.CHUNK_SIZE, + chunk_overlap=request.app.state.config.CHUNK_OVERLAP, + add_start_index=True, + ) + else: + raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter")) + + docs = text_splitter.split_documents(docs) + + if len(docs) == 0: + raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) + + texts = [doc.page_content for doc in docs] + metadatas = [ + { + **doc.metadata, + **(metadata if metadata else {}), + "embedding_config": json.dumps( + { + "engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "model": request.app.state.config.RAG_EMBEDDING_MODEL, + } + ), + } + for doc in docs + ] + + # ChromaDB does not like datetime formats + # for meta-data so convert them to string. 
+ for metadata in metadatas: + for key, value in metadata.items(): + if isinstance(value, datetime): + metadata[key] = str(value) + + try: + if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): + log.info(f"collection {collection_name} already exists") + + if overwrite: + VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) + log.info(f"deleting existing collection {collection_name}") + elif add is False: + log.info( + f"collection {collection_name} already exists, overwrite is False and add is False" + ) + return True + + log.info(f"adding to collection {collection_name}") + embedding_function = get_embedding_function( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.ef, + ( + request.app.state.config.RAG_OPENAI_API_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.RAG_OLLAMA_BASE_URL + ), + ( + request.app.state.config.RAG_OPENAI_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.RAG_OLLAMA_API_KEY + ), + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + ) + + embeddings = embedding_function( + list(map(lambda x: x.replace("\n", " "), texts)) + ) + + items = [ + { + "id": str(uuid.uuid4()), + "text": text, + "vector": embeddings[idx], + "metadata": metadatas[idx], + } + for idx, text in enumerate(texts) + ] + + VECTOR_DB_CLIENT.insert( + collection_name=collection_name, + items=items, + ) + + return True + except Exception as e: + log.exception(e) + raise e + + +class ProcessFileForm(BaseModel): + file_id: str + content: Optional[str] = None + collection_name: Optional[str] = None + + +@router.post("/process/file") +def process_file( + request: Request, + form_data: ProcessFileForm, + user=Depends(get_verified_user), +): + try: + file = Files.get_file_by_id(form_data.file_id) + + collection_name = form_data.collection_name + + if collection_name is None: + collection_name = f"file-{file.id}" + + if form_data.content: + # Update the content in the file + # Usage: /files/{file_id}/data/content/update + + VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}") + + docs = [ + Document( + page_content=form_data.content.replace("
", "\n"), + metadata={ + **file.meta, + "name": file.filename, + "created_by": file.user_id, + "file_id": file.id, + "source": file.filename, + }, + ) + ] + + text_content = form_data.content + elif form_data.collection_name: + # Check if the file has already been processed and save the content + # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update + + result = VECTOR_DB_CLIENT.query( + collection_name=f"file-{file.id}", filter={"file_id": file.id} + ) + + if result is not None and len(result.ids[0]) > 0: + docs = [ + Document( + page_content=result.documents[0][idx], + metadata=result.metadatas[0][idx], + ) + for idx, id in enumerate(result.ids[0]) + ] + else: + docs = [ + Document( + page_content=file.data.get("content", ""), + metadata={ + **file.meta, + "name": file.filename, + "created_by": file.user_id, + "file_id": file.id, + "source": file.filename, + }, + ) + ] + + text_content = file.data.get("content", "") + else: + # Process the file and save the content + # Usage: /files/ + file_path = file.path + if file_path: + file_path = Storage.get_file(file_path) + loader = Loader( + engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL, + PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES, + ) + docs = loader.load( + file.filename, file.meta.get("content_type"), file_path + ) + + docs = [ + Document( + page_content=doc.page_content, + metadata={ + **doc.metadata, + "name": file.filename, + "created_by": file.user_id, + "file_id": file.id, + "source": file.filename, + }, + ) + for doc in docs + ] + else: + docs = [ + Document( + page_content=file.data.get("content", ""), + metadata={ + **file.meta, + "name": file.filename, + "created_by": file.user_id, + "file_id": file.id, + "source": file.filename, + }, + ) + ] + text_content = " ".join([doc.page_content for doc in docs]) + + log.debug(f"text_content: {text_content}") + Files.update_file_data_by_id( + file.id, + {"content": text_content}, + ) + + hash = calculate_sha256_string(text_content) + Files.update_file_hash_by_id(file.id, hash) + + try: + result = save_docs_to_vector_db( + request, + docs=docs, + collection_name=collection_name, + metadata={ + "file_id": file.id, + "name": file.filename, + "hash": hash, + }, + add=(True if form_data.collection_name else False), + ) + + if result: + Files.update_file_metadata_by_id( + file.id, + { + "collection_name": collection_name, + }, + ) + + return { + "status": True, + "collection_name": collection_name, + "filename": file.filename, + "content": text_content, + } + except Exception as e: + raise e + except Exception as e: + log.exception(e) + if "No pandoc was found" in str(e): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + + +class ProcessTextForm(BaseModel): + name: str + content: str + collection_name: Optional[str] = None + + +@router.post("/process/text") +def process_text( + request: Request, + form_data: ProcessTextForm, + user=Depends(get_verified_user), +): + collection_name = form_data.collection_name + if collection_name is None: + collection_name = calculate_sha256_string(form_data.content) + + docs = [ + Document( + page_content=form_data.content, + metadata={"name": form_data.name, "created_by": user.id}, + ) + ] + text_content = form_data.content + log.debug(f"text_content: {text_content}") + + result = 
save_docs_to_vector_db(request, docs, collection_name) + if result: + return { + "status": True, + "collection_name": collection_name, + "content": text_content, + } + else: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + +@router.post("/process/youtube") +def process_youtube_video( + request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user) +): + try: + collection_name = form_data.collection_name + if not collection_name: + collection_name = calculate_sha256_string(form_data.url)[:63] + + loader = YoutubeLoader( + form_data.url, + language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL, + ) + + docs = loader.load() + content = " ".join([doc.page_content for doc in docs]) + log.debug(f"text_content: {content}") + + save_docs_to_vector_db(request, docs, collection_name, overwrite=True) + + return { + "status": True, + "collection_name": collection_name, + "filename": form_data.url, + "file": { + "data": { + "content": content, + }, + "meta": { + "name": form_data.url, + }, + }, + } + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +@router.post("/process/web") +def process_web( + request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user) +): + try: + collection_name = form_data.collection_name + if not collection_name: + collection_name = calculate_sha256_string(form_data.url)[:63] + + loader = get_web_loader( + form_data.url, + verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + ) + docs = loader.load() + content = " ".join([doc.page_content for doc in docs]) + + log.debug(f"text_content: {content}") + save_docs_to_vector_db(request, docs, collection_name, overwrite=True) + + return { + "status": True, + "collection_name": collection_name, + "filename": form_data.url, + "file": { + "data": { + "content": content, + }, + "meta": { + "name": form_data.url, + }, + }, + } + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: + """Search the web using a search engine and return the results as a list of SearchResult objects. 
+ Will look for a search engine API key in environment variables in the following order: + - SEARXNG_QUERY_URL + - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID + - BRAVE_SEARCH_API_KEY + - KAGI_SEARCH_API_KEY + - MOJEEK_SEARCH_API_KEY + - SERPSTACK_API_KEY + - SERPER_API_KEY + - SERPLY_API_KEY + - TAVILY_API_KEY + - SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`) + Args: + query (str): The query to search for + """ + + # TODO: add playwright to search the web + if engine == "searxng": + if request.app.state.config.SEARXNG_QUERY_URL: + return search_searxng( + request.app.state.config.SEARXNG_QUERY_URL, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No SEARXNG_QUERY_URL found in environment variables") + elif engine == "google_pse": + if ( + request.app.state.config.GOOGLE_PSE_API_KEY + and request.app.state.config.GOOGLE_PSE_ENGINE_ID + ): + return search_google_pse( + request.app.state.config.GOOGLE_PSE_API_KEY, + request.app.state.config.GOOGLE_PSE_ENGINE_ID, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception( + "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables" + ) + elif engine == "brave": + if request.app.state.config.BRAVE_SEARCH_API_KEY: + return search_brave( + request.app.state.config.BRAVE_SEARCH_API_KEY, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables") + elif engine == "kagi": + if request.app.state.config.KAGI_SEARCH_API_KEY: + return search_kagi( + request.app.state.config.KAGI_SEARCH_API_KEY, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No KAGI_SEARCH_API_KEY found in environment variables") + elif engine == "mojeek": + if request.app.state.config.MOJEEK_SEARCH_API_KEY: + return search_mojeek( + request.app.state.config.MOJEEK_SEARCH_API_KEY, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables") + elif engine == "serpstack": + if request.app.state.config.SERPSTACK_API_KEY: + return search_serpstack( + request.app.state.config.SERPSTACK_API_KEY, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + https_enabled=request.app.state.config.SERPSTACK_HTTPS, + ) + else: + raise Exception("No SERPSTACK_API_KEY found in environment variables") + elif engine == "serper": + if request.app.state.config.SERPER_API_KEY: + return search_serper( + request.app.state.config.SERPER_API_KEY, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No SERPER_API_KEY found in environment variables") + elif engine == "serply": + if request.app.state.config.SERPLY_API_KEY: + return search_serply( + request.app.state.config.SERPLY_API_KEY, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No SERPLY_API_KEY found in environment variables") + 
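    # Note: not every engine below guards on a configured key. The duckduckgo
    # branch needs no API key, and the jina and bing branches pass whatever
    # key/endpoint is configured straight through, while tavily and searchapi
    # still raise if their keys are unset.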
elif engine == "duckduckgo": + return search_duckduckgo( + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + elif engine == "tavily": + if request.app.state.config.TAVILY_API_KEY: + return search_tavily( + request.app.state.config.TAVILY_API_KEY, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + ) + else: + raise Exception("No TAVILY_API_KEY found in environment variables") + elif engine == "searchapi": + if request.app.state.config.SEARCHAPI_API_KEY: + return search_searchapi( + request.app.state.config.SEARCHAPI_API_KEY, + request.app.state.config.SEARCHAPI_ENGINE, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No SEARCHAPI_API_KEY found in environment variables") + elif engine == "jina": + return search_jina( + request.app.state.config.JINA_API_KEY, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + ) + elif engine == "bing": + return search_bing( + request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + request.app.state.config.BING_SEARCH_V7_ENDPOINT, + str(DEFAULT_LOCALE), + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No search engine API key found in environment variables") + + +@router.post("/process/web/search") +def process_web_search( + request: Request, form_data: SearchForm, user=Depends(get_verified_user) +): + try: + logging.info( + f"trying to web search with {request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}" + ) + web_results = search_web( + request, request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query + ) + except Exception as e: + log.exception(e) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e), + ) + + try: + collection_name = form_data.collection_name + if collection_name == "": + collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[ + :63 + ] + + urls = [result.link for result in web_results] + loader = get_web_loader( + urls=urls, + verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + ) + docs = loader.aload() + + save_docs_to_vector_db(request, docs, collection_name, overwrite=True) + + return { + "status": True, + "collection_name": collection_name, + "filenames": urls, + } + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +class QueryDocForm(BaseModel): + collection_name: str + query: str + k: Optional[int] = None + r: Optional[float] = None + hybrid: Optional[bool] = None + + +@router.post("/query/doc") +def query_doc_handler( + request: Request, + form_data: QueryDocForm, + user=Depends(get_verified_user), +): + try: + if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: + return query_doc_with_hybrid_search( + collection_name=form_data.collection_name, + query=form_data.query, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else request.app.state.config.TOP_K, + reranking_function=request.app.state.rf, + r=( + form_data.r + if form_data.r + else request.app.state.config.RELEVANCE_THRESHOLD + ), + ) + else: + return query_doc( + 
collection_name=form_data.collection_name, + query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query), + k=form_data.k if form_data.k else request.app.state.config.TOP_K, + ) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +class QueryCollectionsForm(BaseModel): + collection_names: list[str] + query: str + k: Optional[int] = None + r: Optional[float] = None + hybrid: Optional[bool] = None + + +@router.post("/query/collection") +def query_collection_handler( + request: Request, + form_data: QueryCollectionsForm, + user=Depends(get_verified_user), +): + try: + if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: + return query_collection_with_hybrid_search( + collection_names=form_data.collection_names, + queries=[form_data.query], + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else request.app.state.config.TOP_K, + reranking_function=request.app.state.rf, + r=( + form_data.r + if form_data.r + else request.app.state.config.RELEVANCE_THRESHOLD + ), + ) + else: + return query_collection( + collection_names=form_data.collection_names, + queries=[form_data.query], + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else request.app.state.config.TOP_K, + ) + + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +#################################### +# +# Vector DB operations +# +#################################### + + +class DeleteForm(BaseModel): + collection_name: str + file_id: str + + +@router.post("/delete") +def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)): + try: + if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name): + file = Files.get_file_by_id(form_data.file_id) + hash = file.hash + + VECTOR_DB_CLIENT.delete( + collection_name=form_data.collection_name, + metadata={"hash": hash}, + ) + return {"status": True} + else: + return {"status": False} + except Exception as e: + log.exception(e) + return {"status": False} + + +@router.post("/reset/db") +def reset_vector_db(user=Depends(get_admin_user)): + VECTOR_DB_CLIENT.reset() + Knowledges.delete_all_knowledge() + + +@router.post("/reset/uploads") +def reset_upload_dir(user=Depends(get_admin_user)) -> bool: + folder = f"{UPLOAD_DIR}" + try: + # Check if the directory exists + if os.path.exists(folder): + # Iterate over all the files and directories in the specified directory + for filename in os.listdir(folder): + file_path = os.path.join(folder, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) # Remove the file or link + elif os.path.isdir(file_path): + shutil.rmtree(file_path) # Remove the directory + except Exception as e: + print(f"Failed to delete {file_path}. Reason: {e}") + else: + print(f"The directory {folder} does not exist") + except Exception as e: + print(f"Failed to process the directory {folder}. 
Reason: {e}") + return True + + +if ENV == "dev": + + @router.get("/ef/{text}") + async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"): + return {"result": request.app.state.EMBEDDING_FUNCTION(text)} diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py new file mode 100644 index 000000000..4990b4e08 --- /dev/null +++ b/backend/open_webui/routers/tasks.py @@ -0,0 +1,514 @@ +from fastapi import APIRouter, Depends, HTTPException, Response, status, Request +from fastapi.responses import JSONResponse, RedirectResponse + +from pydantic import BaseModel +from typing import Optional +import logging + +from open_webui.utils.chat import generate_chat_completion +from open_webui.utils.task import ( + title_generation_template, + query_generation_template, + autocomplete_generation_template, + tags_generation_template, + emoji_generation_template, + moa_response_generation_template, +) +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.constants import TASKS + +from open_webui.routers.pipelines import process_pipeline_inlet_filter +from open_webui.utils.task import get_task_model_id + +from open_webui.config import ( + DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, + DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, + DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, + DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, + DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE, + DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE, +) +from open_webui.env import SRC_LOG_LEVELS + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +router = APIRouter() + + +################################## +# +# Task Endpoints +# +################################## + + +@router.get("/config") +async def get_task_config(request: Request, user=Depends(get_verified_user)): + return { + "TASK_MODEL": request.app.state.config.TASK_MODEL, + "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, + "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, + "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, + "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, + "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + } + + +class TaskConfigForm(BaseModel): + TASK_MODEL: Optional[str] + TASK_MODEL_EXTERNAL: Optional[str] + TITLE_GENERATION_PROMPT_TEMPLATE: str + ENABLE_AUTOCOMPLETE_GENERATION: bool + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int + TAGS_GENERATION_PROMPT_TEMPLATE: str + ENABLE_TAGS_GENERATION: bool + ENABLE_SEARCH_QUERY_GENERATION: bool + ENABLE_RETRIEVAL_QUERY_GENERATION: bool + QUERY_GENERATION_PROMPT_TEMPLATE: str + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str + + +@router.post("/config/update") +async def update_task_config( + request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.TASK_MODEL = form_data.TASK_MODEL + 
request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL + request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( + form_data.TITLE_GENERATION_PROMPT_TEMPLATE + ) + + request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ( + form_data.ENABLE_AUTOCOMPLETE_GENERATION + ) + request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( + form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH + ) + + request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = ( + form_data.TAGS_GENERATION_PROMPT_TEMPLATE + ) + request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION + request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ( + form_data.ENABLE_SEARCH_QUERY_GENERATION + ) + request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ( + form_data.ENABLE_RETRIEVAL_QUERY_GENERATION + ) + + request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = ( + form_data.QUERY_GENERATION_PROMPT_TEMPLATE + ) + request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + ) + + return { + "TASK_MODEL": request.app.state.config.TASK_MODEL, + "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, + "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, + "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, + "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, + "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + } + + +@router.post("/title/completions") +async def generate_title( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating chat title using model {task_model_id} for user {user.email} " + ) + + if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": + template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE + + content = title_generation_template( + template, + form_data["messages"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + **( + {"max_tokens": 50} + if models[task_model_id]["owned_by"] == "ollama" + else { + "max_completion_tokens": 50, + } + ), + "metadata": { + "task": str(TASKS.TITLE_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + 
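    # The token cap is sent as `max_tokens` for Ollama-served task models and
    # as `max_completion_tokens` for other OpenAI-compatible backends, per the
    # `owned_by` check above.
    #
    # A minimal request body for this endpoint would look like the following
    # sketch (the model id "llama3.1" is purely illustrative):
    #
    #   {
    #       "model": "llama3.1",
    #       "messages": [{"role": "user", "content": "Why is the sky blue?"}],
    #       "chat_id": "optional-chat-id"
    #   }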
try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + log.error("Exception occurred", exc_info=True) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": "An internal error has occurred."}, + ) + + +@router.post("/tags/completions") +async def generate_chat_tags( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + if not request.app.state.config.ENABLE_TAGS_GENERATION: + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"detail": "Tags generation is disabled"}, + ) + + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating chat tags using model {task_model_id} for user {user.email} " + ) + + if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": + template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE + + content = tags_generation_template( + template, form_data["messages"], {"name": user.name} + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + "task": str(TASKS.TAGS_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + log.error(f"Error generating chat completion: {e}") + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": "An internal error has occurred."}, + ) + + +@router.post("/queries/completions") +async def generate_queries( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + type = form_data.get("type") + if type == "web_search": + if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Search query generation is disabled", + ) + elif type == "retrieval": + if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Query generation is disabled", + ) + + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating {type} queries using model {task_model_id} for user {user.email}" + ) + + if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "": + template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE + + content = query_generation_template( + template, form_data["messages"], {"name": user.name} + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": 
content}], + "stream": False, + "metadata": { + "task": str(TASKS.QUERY_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + +@router.post("/auto/completions") +async def generate_autocompletion( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Autocompletion generation is disabled", + ) + + type = form_data.get("type") + prompt = form_data.get("prompt") + messages = form_data.get("messages") + + if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0: + if ( + len(prompt) + > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", + ) + + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating autocompletion using model {task_model_id} for user {user.email}" + ) + + if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "": + template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE + + content = autocomplete_generation_template( + template, prompt, messages, type, {"name": user.name} + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + "task": str(TASKS.AUTOCOMPLETE_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + +@router.post("/emoji/completions") +async def generate_emoji( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug(f"generating emoji using model {task_model_id} for user {user.email} ") + + template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE + + content = emoji_generation_template( + template, + form_data["prompt"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + 
"stream": False, + **( + {"max_tokens": 4} + if models[task_model_id]["owned_by"] == "ollama" + else { + "max_completion_tokens": 4, + } + ), + "chat_id": form_data.get("chat_id", None), + "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data}, + } + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + +@router.post("/moa/completions") +async def generate_moa_response( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + models = request.app.state.MODELS + model_id = form_data["model"] + + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug(f"generating MOA model {task_model_id} for user {user.email} ") + + template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE + + content = moa_response_generation_template( + template, + form_data["prompt"], + form_data["responses"], + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": form_data.get("stream", False), + "chat_id": form_data.get("chat_id", None), + "metadata": { + "task": str(TASKS.MOA_RESPONSE_GENERATION), + "task_body": form_data, + }, + } + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) diff --git a/backend/open_webui/apps/webui/routers/tools.py b/backend/open_webui/routers/tools.py similarity index 54% rename from backend/open_webui/apps/webui/routers/tools.py rename to backend/open_webui/routers/tools.py index d1ad89dea..9e95ebe5a 100644 --- a/backend/open_webui/apps/webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -1,51 +1,82 @@ -import os from pathlib import Path from typing import Optional -from open_webui.apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools -from open_webui.apps.webui.utils import load_toolkit_module_by_id, replace_imports -from open_webui.config import CACHE_DIR, DATA_DIR +from open_webui.models.tools import ( + ToolForm, + ToolModel, + ToolResponse, + ToolUserResponse, + Tools, +) +from open_webui.utils.plugin import load_tools_module_by_id, replace_imports +from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.tools import get_tools_specs -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access, has_permission router = APIRouter() ############################ -# GetToolkits +# GetTools ############################ -@router.get("/", response_model=list[ToolResponse]) -async def get_toolkits(user=Depends(get_verified_user)): - toolkits = [toolkit for toolkit in Tools.get_tools()] - return toolkits +@router.get("/", response_model=list[ToolUserResponse]) +async def get_tools(user=Depends(get_verified_user)): + if user.role == "admin": + tools = Tools.get_tools() + else: 
+ tools = Tools.get_tools_by_user_id(user.id, "read") + return tools ############################ -# ExportToolKits +# GetToolList +############################ + + +@router.get("/list", response_model=list[ToolUserResponse]) +async def get_tool_list(user=Depends(get_verified_user)): + if user.role == "admin": + tools = Tools.get_tools() + else: + tools = Tools.get_tools_by_user_id(user.id, "write") + return tools + + +############################ +# ExportTools ############################ @router.get("/export", response_model=list[ToolModel]) -async def get_toolkits(user=Depends(get_admin_user)): - toolkits = [toolkit for toolkit in Tools.get_tools()] - return toolkits +async def export_tools(user=Depends(get_admin_user)): + tools = Tools.get_tools() + return tools ############################ -# CreateNewToolKit +# CreateNewTools ############################ @router.post("/create", response_model=Optional[ToolResponse]) -async def create_new_toolkit( +async def create_new_tools( request: Request, form_data: ToolForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): + if user.role != "admin" and not has_permission( + user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + if not form_data.id.isidentifier(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -54,30 +85,30 @@ async def create_new_toolkit( form_data.id = form_data.id.lower() - toolkit = Tools.get_tool_by_id(form_data.id) - if toolkit is None: + tools = Tools.get_tool_by_id(form_data.id) + if tools is None: try: form_data.content = replace_imports(form_data.content) - toolkit_module, frontmatter = load_toolkit_module_by_id( + tools_module, frontmatter = load_tools_module_by_id( form_data.id, content=form_data.content ) form_data.meta.manifest = frontmatter TOOLS = request.app.state.TOOLS - TOOLS[form_data.id] = toolkit_module + TOOLS[form_data.id] = tools_module specs = get_tools_specs(TOOLS[form_data.id]) - toolkit = Tools.insert_new_tool(user.id, form_data, specs) + tools = Tools.insert_new_tool(user.id, form_data, specs) tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id tool_cache_dir.mkdir(parents=True, exist_ok=True) - if toolkit: - return toolkit + if tools: + return tools else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"), + detail=ERROR_MESSAGES.DEFAULT("Error creating tools"), ) except Exception as e: print(e) @@ -93,16 +124,21 @@ async def create_new_toolkit( ############################ -# GetToolkitById +# GetToolsById ############################ @router.get("/id/{id}", response_model=Optional[ToolModel]) -async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): - toolkit = Tools.get_tool_by_id(id) +async def get_tools_by_id(id: str, user=Depends(get_verified_user)): + tools = Tools.get_tool_by_id(id) - if toolkit: - return toolkit + if tools: + if ( + user.role == "admin" + or tools.user_id == user.id + or has_access(user.id, "read", tools.access_control) + ): + return tools else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -111,26 +147,39 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): ############################ -# UpdateToolkitById +# UpdateToolsById ############################ @router.post("/id/{id}/update", response_model=Optional[ToolModel]) -async def update_toolkit_by_id( +async def 
update_tools_by_id( request: Request, id: str, form_data: ToolForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): + tools = Tools.get_tool_by_id(id) + if not tools: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if tools.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + try: form_data.content = replace_imports(form_data.content) - toolkit_module, frontmatter = load_toolkit_module_by_id( + tools_module, frontmatter = load_tools_module_by_id( id, content=form_data.content ) form_data.meta.manifest = frontmatter TOOLS = request.app.state.TOOLS - TOOLS[id] = toolkit_module + TOOLS[id] = tools_module specs = get_tools_specs(TOOLS[id]) @@ -140,14 +189,14 @@ async def update_toolkit_by_id( } print(updated) - toolkit = Tools.update_tool_by_id(id, updated) + tools = Tools.update_tool_by_id(id, updated) - if toolkit: - return toolkit + if tools: + return tools else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"), + detail=ERROR_MESSAGES.DEFAULT("Error updating tools"), ) except Exception as e: @@ -158,14 +207,28 @@ async def update_toolkit_by_id( ############################ -# DeleteToolkitById +# DeleteToolsById ############################ @router.delete("/id/{id}/delete", response_model=bool) -async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)): - result = Tools.delete_tool_by_id(id) +async def delete_tools_by_id( + request: Request, id: str, user=Depends(get_verified_user) +): + tools = Tools.get_tool_by_id(id) + if not tools: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + if tools.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + result = Tools.delete_tool_by_id(id) if result: TOOLS = request.app.state.TOOLS if id in TOOLS: @@ -180,9 +243,9 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin @router.get("/id/{id}/valves", response_model=Optional[dict]) -async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): - toolkit = Tools.get_tool_by_id(id) - if toolkit: +async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)): + tools = Tools.get_tool_by_id(id) + if tools: try: valves = Tools.get_tool_valves_by_id(id) return valves @@ -204,19 +267,19 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): @router.get("/id/{id}/valves/spec", response_model=Optional[dict]) -async def get_toolkit_valves_spec_by_id( - request: Request, id: str, user=Depends(get_admin_user) +async def get_tools_valves_spec_by_id( + request: Request, id: str, user=Depends(get_verified_user) ): - toolkit = Tools.get_tool_by_id(id) - if toolkit: + tools = Tools.get_tool_by_id(id) + if tools: if id in request.app.state.TOOLS: - toolkit_module = request.app.state.TOOLS[id] + tools_module = request.app.state.TOOLS[id] else: - toolkit_module, _ = load_toolkit_module_by_id(id) - request.app.state.TOOLS[id] = toolkit_module + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module - if hasattr(toolkit_module, "Valves"): - Valves = toolkit_module.Valves + if hasattr(tools_module, "Valves"): + Valves = tools_module.Valves return Valves.schema() return 
None else: @@ -232,42 +295,39 @@ async def get_toolkit_valves_spec_by_id( @router.post("/id/{id}/valves/update", response_model=Optional[dict]) -async def update_toolkit_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_admin_user) +async def update_tools_valves_by_id( + request: Request, id: str, form_data: dict, user=Depends(get_verified_user) ): - toolkit = Tools.get_tool_by_id(id) - if toolkit: - if id in request.app.state.TOOLS: - toolkit_module = request.app.state.TOOLS[id] - else: - toolkit_module, _ = load_toolkit_module_by_id(id) - request.app.state.TOOLS[id] = toolkit_module - - if hasattr(toolkit_module, "Valves"): - Valves = toolkit_module.Valves - - try: - form_data = {k: v for k, v in form_data.items() if v is not None} - valves = Valves(**form_data) - Tools.update_tool_valves_by_id(id, valves.model_dump()) - return valves.model_dump() - except Exception as e: - print(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(str(e)), - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - - else: + tools = Tools.get_tool_by_id(id) + if not tools: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND, ) + if id in request.app.state.TOOLS: + tools_module = request.app.state.TOOLS[id] + else: + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module + + if not hasattr(tools_module, "Valves"): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + Valves = tools_module.Valves + + try: + form_data = {k: v for k, v in form_data.items() if v is not None} + valves = Valves(**form_data) + Tools.update_tool_valves_by_id(id, valves.model_dump()) + return valves.model_dump() + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(str(e)), + ) ############################ @@ -276,9 +336,9 @@ async def update_toolkit_valves_by_id( @router.get("/id/{id}/valves/user", response_model=Optional[dict]) -async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)): - toolkit = Tools.get_tool_by_id(id) - if toolkit: +async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)): + tools = Tools.get_tool_by_id(id) + if tools: try: user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id) return user_valves @@ -295,19 +355,19 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user) @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) -async def get_toolkit_user_valves_spec_by_id( +async def get_tools_user_valves_spec_by_id( request: Request, id: str, user=Depends(get_verified_user) ): - toolkit = Tools.get_tool_by_id(id) - if toolkit: + tools = Tools.get_tool_by_id(id) + if tools: if id in request.app.state.TOOLS: - toolkit_module = request.app.state.TOOLS[id] + tools_module = request.app.state.TOOLS[id] else: - toolkit_module, _ = load_toolkit_module_by_id(id) - request.app.state.TOOLS[id] = toolkit_module + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module - if hasattr(toolkit_module, "UserValves"): - UserValves = toolkit_module.UserValves + if hasattr(tools_module, "UserValves"): + UserValves = tools_module.UserValves return UserValves.schema() return None else: @@ -318,20 +378,20 @@ async def get_toolkit_user_valves_spec_by_id( 
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) -async def update_toolkit_user_valves_by_id( +async def update_tools_user_valves_by_id( request: Request, id: str, form_data: dict, user=Depends(get_verified_user) ): - toolkit = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id) - if toolkit: + if tools: if id in request.app.state.TOOLS: - toolkit_module = request.app.state.TOOLS[id] + tools_module = request.app.state.TOOLS[id] else: - toolkit_module, _ = load_toolkit_module_by_id(id) - request.app.state.TOOLS[id] = toolkit_module + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module - if hasattr(toolkit_module, "UserValves"): - UserValves = toolkit_module.UserValves + if hasattr(tools_module, "UserValves"): + UserValves = tools_module.UserValves try: form_data = {k: v for k, v in form_data.items() if v is not None} diff --git a/backend/open_webui/apps/webui/routers/users.py b/backend/open_webui/routers/users.py similarity index 85% rename from backend/open_webui/apps/webui/routers/users.py rename to backend/open_webui/routers/users.py index abc540efa..1206d56f2 100644 --- a/backend/open_webui/apps/webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -1,9 +1,9 @@ import logging from typing import Optional -from open_webui.apps.webui.models.auths import Auths -from open_webui.apps.webui.models.chats import Chats -from open_webui.apps.webui.models.users import ( +from open_webui.models.auths import Auths +from open_webui.models.chats import Chats +from open_webui.models.users import ( UserModel, UserRoleUpdateForm, Users, @@ -14,7 +14,7 @@ from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_password_hash, get_verified_user +from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -31,21 +31,58 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user) return Users.get_users(skip, limit) +############################ +# User Groups +############################ + + +@router.get("/groups") +async def get_user_groups(user=Depends(get_verified_user)): + return Users.get_user_groups(user.id) + + ############################ # User Permissions ############################ -@router.get("/permissions/user") +@router.get("/permissions") +async def get_user_permissisions(user=Depends(get_verified_user)): + return Users.get_user_groups(user.id) + + +############################ +# User Default Permissions +############################ +class WorkspacePermissions(BaseModel): + models: bool + knowledge: bool + prompts: bool + tools: bool + + +class ChatPermissions(BaseModel): + file_upload: bool + delete: bool + edit: bool + temporary: bool + + +class UserPermissions(BaseModel): + workspace: WorkspacePermissions + chat: ChatPermissions + + +@router.get("/default/permissions") async def get_user_permissions(request: Request, user=Depends(get_admin_user)): return request.app.state.config.USER_PERMISSIONS -@router.post("/permissions/user") +@router.post("/default/permissions") async def update_user_permissions( - request: Request, form_data: dict, user=Depends(get_admin_user) + request: Request, form_data: UserPermissions, user=Depends(get_admin_user) ): - request.app.state.config.USER_PERMISSIONS = 
form_data + request.app.state.config.USER_PERMISSIONS = form_data.model_dump() return request.app.state.config.USER_PERMISSIONS diff --git a/backend/open_webui/apps/webui/routers/utils.py b/backend/open_webui/routers/utils.py similarity index 93% rename from backend/open_webui/apps/webui/routers/utils.py rename to backend/open_webui/routers/utils.py index 0ab0f6b15..ea73e9759 100644 --- a/backend/open_webui/apps/webui/routers/utils.py +++ b/backend/open_webui/routers/utils.py @@ -1,7 +1,7 @@ import black import markdown -from open_webui.apps.webui.models.chats import ChatTitleMessagesForm +from open_webui.models.chats import ChatTitleMessagesForm from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Response, status @@ -9,7 +9,7 @@ from pydantic import BaseModel from starlette.responses import FileResponse from open_webui.utils.misc import get_gravatar_url from open_webui.utils.pdf_generator import PDFGenerator -from open_webui.utils.utils import get_admin_user +from open_webui.utils.auth import get_admin_user router = APIRouter() @@ -76,7 +76,7 @@ async def download_db(user=Depends(get_admin_user)): status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - from open_webui.apps.webui.internal.db import engine + from open_webui.internal.db import engine if engine.name != "sqlite": raise HTTPException( diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/socket/main.py similarity index 96% rename from backend/open_webui/apps/socket/main.py rename to backend/open_webui/socket/main.py index fca268a6b..8343be666 100644 --- a/backend/open_webui/apps/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -4,14 +4,14 @@ import logging import sys import time -from open_webui.apps.webui.models.users import Users +from open_webui.models.users import Users from open_webui.env import ( ENABLE_WEBSOCKET_SUPPORT, WEBSOCKET_MANAGER, WEBSOCKET_REDIS_URL, ) -from open_webui.utils.utils import decode_token -from open_webui.apps.socket.utils import RedisDict +from open_webui.utils.auth import decode_token +from open_webui.socket.utils import RedisDict from open_webui.env import ( GLOBAL_LOG_LEVEL, @@ -171,6 +171,11 @@ async def user_count(sid): await sio.emit("user-count", {"count": len(USER_POOL.items())}) +@sio.on("chat") +async def chat(sid, data): + print("chat", sid, SESSION_POOL[sid], data) + + @sio.event async def disconnect(sid): if sid in SESSION_POOL: diff --git a/backend/open_webui/apps/socket/utils.py b/backend/open_webui/socket/utils.py similarity index 100% rename from backend/open_webui/apps/socket/utils.py rename to backend/open_webui/socket/utils.py diff --git a/backend/open_webui/static/assets/pdf-style.css b/backend/open_webui/static/assets/pdf-style.css index db9ac83dd..85c36271c 100644 --- a/backend/open_webui/static/assets/pdf-style.css +++ b/backend/open_webui/static/assets/pdf-style.css @@ -26,7 +26,7 @@ html { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'NotoSans', 'NotoSansJP', 'NotoSansKR', - 'NotoSansSC', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium', Roboto, + 'NotoSansSC', 'Twemoji', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium', Roboto, 'Helvetica Neue', Arial, sans-serif; font-size: 14px; /* Default font size */ line-height: 1.5; @@ -40,7 +40,7 @@ html { body { margin: 0; - color: #212529; + padding: 0; background-color: #fff; width: auto; } diff --git 
a/backend/open_webui/static/fonts/Twemoji.ttf b/backend/open_webui/static/fonts/Twemoji.ttf new file mode 100644 index 000000000..281d356d9 Binary files /dev/null and b/backend/open_webui/static/fonts/Twemoji.ttf differ diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py index a3168ead8..76e4fc48f 100644 --- a/backend/open_webui/storage/provider.py +++ b/backend/open_webui/storage/provider.py @@ -51,7 +51,10 @@ class StorageProvider: try: self.s3_client.upload_file(file_path, self.bucket_name, filename) - return open(file_path, "rb").read(), file_path + return ( + open(file_path, "rb").read(), + "s3://" + self.bucket_name + "/" + filename, + ) except ClientError as e: raise RuntimeError(f"Error uploading file to S3: {e}") diff --git a/backend/open_webui/test/apps/webui/routers/test_auths.py b/backend/open_webui/test/apps/webui/routers/test_auths.py index bc14fb8dd..f0f69e26d 100644 --- a/backend/open_webui/test/apps/webui/routers/test_auths.py +++ b/backend/open_webui/test/apps/webui/routers/test_auths.py @@ -7,8 +7,8 @@ class TestAuths(AbstractPostgresTest): def setup_class(cls): super().setup_class() - from open_webui.apps.webui.models.auths import Auths - from open_webui.apps.webui.models.users import Users + from open_webui.models.auths import Auths + from open_webui.models.users import Users cls.users = Users cls.auths = Auths @@ -26,7 +26,7 @@ class TestAuths(AbstractPostgresTest): } def test_update_profile(self): - from open_webui.utils.utils import get_password_hash + from open_webui.utils.auth import get_password_hash user = self.auths.insert_new_auth( email="john.doe@openwebui.com", @@ -47,7 +47,7 @@ class TestAuths(AbstractPostgresTest): assert db_user.profile_image_url == "/user2.png" def test_update_password(self): - from open_webui.utils.utils import get_password_hash + from open_webui.utils.auth import get_password_hash user = self.auths.insert_new_auth( email="john.doe@openwebui.com", @@ -74,7 +74,7 @@ class TestAuths(AbstractPostgresTest): assert new_auth is not None def test_signin(self): - from open_webui.utils.utils import get_password_hash + from open_webui.utils.auth import get_password_hash user = self.auths.insert_new_auth( email="john.doe@openwebui.com", diff --git a/backend/open_webui/test/apps/webui/routers/test_chats.py b/backend/open_webui/test/apps/webui/routers/test_chats.py index 935316fd8..a36a01fb1 100644 --- a/backend/open_webui/test/apps/webui/routers/test_chats.py +++ b/backend/open_webui/test/apps/webui/routers/test_chats.py @@ -12,7 +12,7 @@ class TestChats(AbstractPostgresTest): def setup_method(self): super().setup_method() - from open_webui.apps.webui.models.chats import ChatForm, Chats + from open_webui.models.chats import ChatForm, Chats self.chats = Chats self.chats.insert_new_chat( @@ -88,7 +88,7 @@ class TestChats(AbstractPostgresTest): def test_get_user_archived_chats(self): self.chats.archive_all_chats_by_user_id("2") - from open_webui.apps.webui.internal.db import Session + from open_webui.internal.db import Session Session.commit() with mock_webui_user(id="2"): diff --git a/backend/open_webui/test/apps/webui/routers/test_models.py b/backend/open_webui/test/apps/webui/routers/test_models.py index 1d52658b8..c16ca9d07 100644 --- a/backend/open_webui/test/apps/webui/routers/test_models.py +++ b/backend/open_webui/test/apps/webui/routers/test_models.py @@ -7,7 +7,7 @@ class TestModels(AbstractPostgresTest): def setup_class(cls): super().setup_class() - from open_webui.apps.webui.models.models import 
Model + from open_webui.models.models import Model cls.models = Model diff --git a/backend/open_webui/test/apps/webui/routers/test_users.py b/backend/open_webui/test/apps/webui/routers/test_users.py index 6facf7055..1a58ab147 100644 --- a/backend/open_webui/test/apps/webui/routers/test_users.py +++ b/backend/open_webui/test/apps/webui/routers/test_users.py @@ -25,7 +25,7 @@ class TestUsers(AbstractPostgresTest): def setup_class(cls): super().setup_class() - from open_webui.apps.webui.models.users import Users + from open_webui.models.users import Users cls.users = Users diff --git a/backend/open_webui/test/util/abstract_integration_test.py b/backend/open_webui/test/util/abstract_integration_test.py index 2814731e0..e8492befb 100644 --- a/backend/open_webui/test/util/abstract_integration_test.py +++ b/backend/open_webui/test/util/abstract_integration_test.py @@ -115,7 +115,7 @@ class AbstractPostgresTest(AbstractIntegrationTest): pytest.fail(f"Could not setup test environment: {ex}") def _check_db_connection(self): - from open_webui.apps.webui.internal.db import Session + from open_webui.internal.db import Session retries = 10 while retries > 0: @@ -139,7 +139,7 @@ class AbstractPostgresTest(AbstractIntegrationTest): cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True) def teardown_method(self): - from open_webui.apps.webui.internal.db import Session + from open_webui.internal.db import Session # rollback everything not yet committed Session.commit() diff --git a/backend/open_webui/test/util/mock_user.py b/backend/open_webui/test/util/mock_user.py index 96456a2c8..7ce64dffa 100644 --- a/backend/open_webui/test/util/mock_user.py +++ b/backend/open_webui/test/util/mock_user.py @@ -5,7 +5,7 @@ from fastapi import FastAPI @contextmanager def mock_webui_user(**kwargs): - from open_webui.apps.webui.main import app + from open_webui.routers.webui import app with mock_user(app, **kwargs): yield @@ -13,13 +13,13 @@ def mock_webui_user(**kwargs): @contextmanager def mock_user(app: FastAPI, **kwargs): - from open_webui.utils.utils import ( + from open_webui.utils.auth import ( get_current_user, get_verified_user, get_admin_user, get_current_user_by_api_key, ) - from open_webui.apps.webui.models.users import User + from open_webui.models.users import User def create_user(): user_parameters = { diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py new file mode 100644 index 000000000..3b3e75a8b --- /dev/null +++ b/backend/open_webui/utils/access_control.py @@ -0,0 +1,95 @@ +from typing import Optional, Union, List, Dict, Any +from open_webui.models.groups import Groups +import json + + +def get_permissions( + user_id: str, + default_permissions: Dict[str, Any], +) -> Dict[str, Any]: + """ + Get all permissions for a user by combining the permissions of all groups the user is a member of. + If a permission is defined in multiple groups, the most permissive value is used (True > False). + Permissions are nested in a dict with the permission key as the key and a boolean as the value. 
+ """ + + def combine_permissions( + permissions: Dict[str, Any], group_permissions: Dict[str, Any] + ) -> Dict[str, Any]: + """Combine permissions from multiple groups by taking the most permissive value.""" + for key, value in group_permissions.items(): + if isinstance(value, dict): + if key not in permissions: + permissions[key] = {} + permissions[key] = combine_permissions(permissions[key], value) + else: + if key not in permissions: + permissions[key] = value + else: + permissions[key] = permissions[key] or value + return permissions + + user_groups = Groups.get_groups_by_member_id(user_id) + + # deep copy default permissions to avoid modifying the original dict + permissions = json.loads(json.dumps(default_permissions)) + + for group in user_groups: + group_permissions = group.permissions + permissions = combine_permissions(permissions, group_permissions) + + return permissions + + +def has_permission( + user_id: str, + permission_key: str, + default_permissions: Dict[str, bool] = {}, +) -> bool: + """ + Check if a user has a specific permission by checking the group permissions + and falls back to default permissions if not found in any group. + + Permission keys can be hierarchical and separated by dots ('.'). + """ + + def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool: + """Traverse permissions dict using a list of keys (from dot-split permission_key).""" + for key in keys: + if key not in permissions: + return False # If any part of the hierarchy is missing, deny access + permissions = permissions[key] # Go one level deeper + + return bool(permissions) # Return the boolean at the final level + + permission_hierarchy = permission_key.split(".") + + # Retrieve user group permissions + user_groups = Groups.get_groups_by_member_id(user_id) + + for group in user_groups: + group_permissions = group.permissions + if get_permission(group_permissions, permission_hierarchy): + return True + + # Check default permissions afterwards if the group permissions don't allow it + return get_permission(default_permissions, permission_hierarchy) + + +def has_access( + user_id: str, + type: str = "write", + access_control: Optional[dict] = None, +) -> bool: + if access_control is None: + return type == "read" + + user_groups = Groups.get_groups_by_member_id(user_id) + user_group_ids = [group.id for group in user_groups] + permission_access = access_control.get(type, {}) + permitted_group_ids = permission_access.get("group_ids", []) + permitted_user_ids = permission_access.get("user_ids", []) + + return user_id in permitted_user_ids or any( + group_id in permitted_group_ids for group_id in user_group_ids + ) diff --git a/backend/open_webui/utils/utils.py b/backend/open_webui/utils/auth.py similarity index 88% rename from backend/open_webui/utils/utils.py rename to backend/open_webui/utils/auth.py index 79faa1831..e1a0ca671 100644 --- a/backend/open_webui/utils/utils.py +++ b/backend/open_webui/utils/auth.py @@ -1,12 +1,15 @@ import logging import uuid -from datetime import UTC, datetime, timedelta -from typing import Optional, Union - import jwt -from open_webui.apps.webui.models.users import Users + +from datetime import UTC, datetime, timedelta +from typing import Optional, Union, List, Dict + +from open_webui.models.users import Users + from open_webui.constants import ERROR_MESSAGES from open_webui.env import WEBUI_SECRET_KEY + from fastapi import Depends, HTTPException, Request, Response, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from 
passlib.context import CryptContext @@ -88,10 +91,21 @@ def get_current_user( # auth by api key if token.startswith("sk-"): + if not request.state.enable_api_key: + raise HTTPException( + status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED + ) return get_current_user_by_api_key(token) # auth by jwt token - data = decode_token(token) + try: + data = decode_token(token) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token", + ) + if data is not None and "id" in data: user = Users.get_user_by_id(data["id"]) if user is None: diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py new file mode 100644 index 000000000..56904d1d8 --- /dev/null +++ b/backend/open_webui/utils/chat.py @@ -0,0 +1,374 @@ +import time +import logging +import sys + +from aiocache import cached +from typing import Any, Optional +import random +import json +import inspect + +from fastapi import Request +from starlette.responses import Response, StreamingResponse + + +from open_webui.models.users import UserModel + +from open_webui.socket.main import ( + get_event_call, + get_event_emitter, +) +from open_webui.functions import generate_function_chat_completion + +from open_webui.routers.openai import ( + generate_chat_completion as generate_openai_chat_completion, +) + +from open_webui.routers.ollama import ( + generate_chat_completion as generate_ollama_chat_completion, +) + +from open_webui.routers.pipelines import ( + process_pipeline_inlet_filter, + process_pipeline_outlet_filter, +) + +from open_webui.models.functions import Functions +from open_webui.models.models import Models + + +from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.models import get_all_models, check_model_access +from open_webui.utils.payload import convert_payload_openai_to_ollama +from open_webui.utils.response import ( + convert_response_ollama_to_openai, + convert_streaming_response_ollama_to_openai, +) + +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL + + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +async def generate_chat_completion( + request: Request, + form_data: dict, + user: Any, + bypass_filter: bool = False, +): + if BYPASS_MODEL_ACCESS_CONTROL: + bypass_filter = True + + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise Exception("Model not found") + + # Process the form_data through the pipeline + try: + form_data = process_pipeline_inlet_filter(request, form_data, user, models) + except Exception as e: + raise e + + model = models[model_id] + + # Check if user has access to the model + if not bypass_filter and user.role == "user": + try: + check_model_access(user, model) + except Exception as e: + raise e + + if model["owned_by"] == "arena": + model_ids = model.get("info", {}).get("meta", {}).get("model_ids") + filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") + if model_ids and filter_mode == "exclude": + model_ids = [ + model["id"] + for model in await get_all_models(request) + if model.get("owned_by") != "arena" and model["id"] not in model_ids + ] + + selected_model_id = None + if isinstance(model_ids, list) and model_ids: + selected_model_id = random.choice(model_ids) + else: + model_ids = [ + model["id"] + for model in await get_all_models(request) + if model.get("owned_by") != 
"arena" + ] + selected_model_id = random.choice(model_ids) + + form_data["model"] = selected_model_id + + if form_data.get("stream") == True: + + async def stream_wrapper(stream): + yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" + async for chunk in stream: + yield chunk + + response = await generate_chat_completion( + form_data, user, bypass_filter=True + ) + return StreamingResponse( + stream_wrapper(response.body_iterator), media_type="text/event-stream" + ) + else: + return { + **(await generate_chat_completion(form_data, user, bypass_filter=True)), + "selected_model_id": selected_model_id, + } + + if model.get("pipe"): + # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter + return await generate_function_chat_completion( + form_data, user=user, models=models + ) + if model["owned_by"] == "ollama": + # Using /ollama/api/chat endpoint + form_data = convert_payload_openai_to_ollama(form_data) + response = await generate_ollama_chat_completion( + request=request, form_data=form_data, user=user, bypass_filter=bypass_filter + ) + if form_data.get("stream"): + response.headers["content-type"] = "text/event-stream" + return StreamingResponse( + convert_streaming_response_ollama_to_openai(response), + headers=dict(response.headers), + ) + else: + return convert_response_ollama_to_openai(response) + else: + return await generate_openai_chat_completion( + request=request, form_data=form_data, user=user, bypass_filter=bypass_filter + ) + + +async def chat_completed(request: Request, form_data: dict, user: Any): + await get_all_models(request) + models = request.app.state.MODELS + + data = form_data + model_id = data["model"] + if model_id not in models: + raise Exception("Model not found") + + model = models[model_id] + + try: + data = process_pipeline_outlet_filter(request, data, user, models) + except Exception as e: + return Exception(f"Error: {e}") + + __event_emitter__ = get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + __event_call__ = get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel to include vavles + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [function.id for function in Functions.get_global_filter_functions()] + if "info" in model and "meta" in model["info"]: + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) + + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + filter_ids = [ + filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids + ] + + # Sort filter_ids by priority, using the get_priority function + filter_ids.sort(key=get_priority) + + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if not filter: + continue + + if filter_id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + request.app.state.FUNCTIONS[filter_id] = function_module + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + 
valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) + + if not hasattr(function_module, "outlet"): + continue + try: + outlet = function_module.outlet + + # Get the signature of the function + sig = inspect.signature(outlet) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": filter_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__request__": request, + } + + # Add extra params contained in the function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id + ) + ) + except Exception as e: + print(e) + + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(outlet): + data = await outlet(**params) + else: + data = outlet(**params) + + except Exception as e: + return Exception(f"Error: {e}") + + return data + + +async def chat_action(request: Request, action_id: str, form_data: dict, user: Any): + if "." in action_id: + action_id, sub_action_id = action_id.split(".") + else: + sub_action_id = None + + action = Functions.get_function_by_id(action_id) + if not action: + raise Exception(f"Action not found: {action_id}") + + await get_all_models(request) + models = request.app.state.MODELS + + data = form_data + model_id = data["model"] + + if model_id not in models: + raise Exception("Model not found") + model = models[model_id] + + __event_emitter__ = get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + __event_call__ = get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + if action_id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[action_id] + else: + function_module, _, _ = load_function_module_by_id(action_id) + request.app.state.FUNCTIONS[action_id] = function_module + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(action_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + + if hasattr(function_module, "action"): + try: + action = function_module.action + + # Get the signature of the function + sig = inspect.signature(action) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": sub_action_id if sub_action_id is not None else action_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__request__": request, + } + + # Add extra params contained in the function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + action_id, user.id + ) + ) + except Exception as e: + print(e) + + 
params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(action): + data = await action(**params) + else: + data = action(**params) + + except Exception as e: + return Exception(f"Error: {e}") + + return data diff --git a/backend/open_webui/apps/images/utils/comfyui.py b/backend/open_webui/utils/images/comfyui.py similarity index 100% rename from backend/open_webui/apps/images/utils/comfyui.py rename to backend/open_webui/utils/images/comfyui.py diff --git a/backend/open_webui/utils/logo.png b/backend/open_webui/utils/logo.png deleted file mode 100644 index 519af1db6..000000000 Binary files a/backend/open_webui/utils/logo.png and /dev/null differ diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py new file mode 100644 index 000000000..1d2bc2b99 --- /dev/null +++ b/backend/open_webui/utils/middleware.py @@ -0,0 +1,508 @@ +import time +import logging +import sys + +from aiocache import cached +from typing import Any, Optional +import random +import json +import inspect + +from fastapi import Request +from starlette.responses import Response, StreamingResponse + + +from open_webui.socket.main import ( + get_event_call, + get_event_emitter, +) +from open_webui.routers.tasks import generate_queries + + +from open_webui.models.users import UserModel +from open_webui.models.functions import Functions +from open_webui.models.models import Models + +from open_webui.retrieval.utils import get_sources_from_files + + +from open_webui.utils.chat import generate_chat_completion +from open_webui.utils.task import ( + get_task_model_id, + rag_template, + tools_function_calling_generation_template, +) +from open_webui.utils.misc import ( + add_or_update_system_message, + get_last_user_message, + prepend_to_first_user_message_content, +) +from open_webui.utils.tools import get_tools +from open_webui.utils.plugin import load_function_module_by_id + + +from open_webui.config import DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL +from open_webui.constants import TASKS + + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +async def chat_completion_filter_functions_handler(request, body, model, extra_params): + skip_files = None + + def get_filter_function_ids(model): + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [ + function.id for function in Functions.get_global_filter_functions() + ] + if "info" in model and "meta" in model["info"]: + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) + + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + + filter_ids = [ + filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids + ] + + filter_ids.sort(key=get_priority) + return filter_ids + + filter_ids = get_filter_function_ids(model) + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if not filter: + continue + + if filter_id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + 
request.app.state.FUNCTIONS[filter_id] = function_module + + # Check if the function has a file_handler variable + if hasattr(function_module, "file_handler"): + skip_files = function_module.file_handler + + # Apply valves to the function + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) + + if hasattr(function_module, "inlet"): + try: + inlet = function_module.inlet + + # Create a dictionary of parameters to be passed to the function + params = {"body": body} | { + k: v + for k, v in { + **extra_params, + "__model__": model, + "__id__": filter_id, + }.items() + if k in inspect.signature(inlet).parameters + } + + if "__user__" in params and hasattr(function_module, "UserValves"): + try: + params["__user__"]["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, params["__user__"]["id"] + ) + ) + except Exception as e: + print(e) + + if inspect.iscoroutinefunction(inlet): + body = await inlet(**params) + else: + body = inlet(**params) + + except Exception as e: + print(f"Error: {e}") + raise e + + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] + + return body, {} + + +async def chat_completion_tools_handler( + request: Request, body: dict, user: UserModel, models, extra_params: dict +) -> tuple[dict, dict]: + async def get_content_from_response(response) -> Optional[str]: + content = None + if hasattr(response, "body_iterator"): + async for chunk in response.body_iterator: + data = json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() + else: + content = response["choices"][0]["message"]["content"] + return content + + def get_tools_function_calling_payload(messages, task_model_id, content): + user_message = get_last_user_message(messages) + history = "\n".join( + f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" + for message in messages[::-1][:4] + ) + + prompt = f"History:\n{history}\nQuery: {user_message}" + + return { + "model": task_model_id, + "messages": [ + {"role": "system", "content": content}, + {"role": "user", "content": f"Query: {prompt}"}, + ], + "stream": False, + "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, + } + + # If tool_ids field is present, call the functions + metadata = body.get("metadata", {}) + + tool_ids = metadata.get("tool_ids", None) + log.debug(f"{tool_ids=}") + if not tool_ids: + return body, {} + + skip_files = False + sources = [] + + task_model_id = get_task_model_id( + body["model"], + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + tools = get_tools( + request, + tool_ids, + user, + { + **extra_params, + "__model__": models[task_model_id], + "__messages__": body["messages"], + "__files__": metadata.get("files", []), + }, + ) + log.info(f"{tools=}") + + specs = [tool["spec"] for tool in tools.values()] + tools_specs = json.dumps(specs) + + if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "": + template = request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + else: + template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + + tools_function_calling_prompt = tools_function_calling_generation_template( + template, tools_specs + ) + 
log.info(f"{tools_function_calling_prompt=}") + payload = get_tools_function_calling_payload( + body["messages"], task_model_id, tools_function_calling_prompt + ) + + try: + response = await generate_chat_completion(request, form_data=payload, user=user) + log.debug(f"{response=}") + content = await get_content_from_response(response) + log.debug(f"{content=}") + + if not content: + return body, {} + + try: + content = content[content.find("{") : content.rfind("}") + 1] + if not content: + raise Exception("No JSON object found in the response") + + result = json.loads(content) + + tool_function_name = result.get("name", None) + if tool_function_name not in tools: + return body, {} + + tool_function_params = result.get("parameters", {}) + + try: + required_params = ( + tools[tool_function_name] + .get("spec", {}) + .get("parameters", {}) + .get("required", []) + ) + tool_function = tools[tool_function_name]["callable"] + tool_function_params = { + k: v + for k, v in tool_function_params.items() + if k in required_params + } + tool_output = await tool_function(**tool_function_params) + + except Exception as e: + tool_output = str(e) + + if isinstance(tool_output, str): + if tools[tool_function_name]["citation"]: + sources.append( + { + "source": { + "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + }, + "document": [tool_output], + "metadata": [ + { + "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + } + ], + } + ) + else: + sources.append( + { + "source": {}, + "document": [tool_output], + "metadata": [ + { + "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + } + ], + } + ) + + if tools[tool_function_name]["file_handler"]: + skip_files = True + + except Exception as e: + log.exception(f"Error: {e}") + content = None + except Exception as e: + log.exception(f"Error: {e}") + content = None + + log.debug(f"tool_contexts: {sources}") + + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] + + return body, {"sources": sources} + + +async def chat_completion_files_handler( + request: Request, body: dict, user: UserModel +) -> tuple[dict, dict[str, list]]: + sources = [] + + if files := body.get("metadata", {}).get("files", None): + try: + queries_response = await generate_queries( + { + "model": body["model"], + "messages": body["messages"], + "type": "retrieval", + }, + user, + ) + queries_response = queries_response["choices"][0]["message"]["content"] + + try: + bracket_start = queries_response.find("{") + bracket_end = queries_response.rfind("}") + 1 + + if bracket_start == -1 or bracket_end == -1: + raise Exception("No JSON object found in the response") + + queries_response = queries_response[bracket_start:bracket_end] + queries_response = json.loads(queries_response) + except Exception as e: + queries_response = {"queries": [queries_response]} + + queries = queries_response.get("queries", []) + except Exception as e: + queries = [] + + if len(queries) == 0: + queries = [get_last_user_message(body["messages"])] + + sources = get_sources_from_files( + files=files, + queries=queries, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=request.app.state.config.TOP_K, + reranking_function=request.app.state.rf, + r=request.app.state.config.RELEVANCE_THRESHOLD, + hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, + ) + + log.debug(f"rag_contexts:sources: {sources}") + return body, {"sources": sources} + + +async def 
process_chat_payload(request, form_data, user, model): + metadata = { + "chat_id": form_data.pop("chat_id", None), + "message_id": form_data.pop("id", None), + "session_id": form_data.pop("session_id", None), + "tool_ids": form_data.get("tool_ids", None), + "files": form_data.get("files", None), + } + form_data["metadata"] = metadata + + extra_params = { + "__event_emitter__": get_event_emitter(metadata), + "__event_call__": get_event_call(metadata), + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + "__metadata__": metadata, + "__request__": request, + } + + # Initialize events to store additional event to be sent to the client + # Initialize contexts and citation + models = request.app.state.MODELS + events = [] + sources = [] + + try: + form_data, flags = await chat_completion_filter_functions_handler( + request, form_data, model, extra_params + ) + except Exception as e: + return Exception(f"Error: {e}") + + tool_ids = form_data.pop("tool_ids", None) + files = form_data.pop("files", None) + + metadata = { + **metadata, + "tool_ids": tool_ids, + "files": files, + } + form_data["metadata"] = metadata + + try: + form_data, flags = await chat_completion_tools_handler( + request, form_data, user, models, extra_params + ) + sources.extend(flags.get("sources", [])) + except Exception as e: + log.exception(e) + + try: + form_data, flags = await chat_completion_files_handler(request, form_data, user) + sources.extend(flags.get("sources", [])) + except Exception as e: + log.exception(e) + + # If context is not empty, insert it into the messages + if len(sources) > 0: + context_string = "" + for source_idx, source in enumerate(sources): + source_id = source.get("source", {}).get("name", "") + + if "document" in source: + for doc_idx, doc_context in enumerate(source["document"]): + metadata = source.get("metadata") + doc_source_id = None + + if metadata: + doc_source_id = metadata[doc_idx].get("source", source_id) + + if source_id: + context_string += f"{doc_source_id if doc_source_id is not None else source_id}{doc_context}\n" + else: + # If there is no source_id, then do not include the source_id tag + context_string += f"{doc_context}\n" + + context_string = context_string.strip() + prompt = get_last_user_message(form_data["messages"]) + + if prompt is None: + raise Exception("No user message found") + if ( + request.app.state.config.RELEVANCE_THRESHOLD == 0 + and context_string.strip() == "" + ): + log.debug( + f"With a 0 relevancy threshold for RAG, the context cannot be empty" + ) + + # Workaround for Ollama 2.0+ system prompt issue + # TODO: replace with add_or_update_system_message + if model["owned_by"] == "ollama": + form_data["messages"] = prepend_to_first_user_message_content( + rag_template( + request.app.state.config.RAG_TEMPLATE, context_string, prompt + ), + form_data["messages"], + ) + else: + form_data["messages"] = add_or_update_system_message( + rag_template( + request.app.state.config.RAG_TEMPLATE, context_string, prompt + ), + form_data["messages"], + ) + + # If there are citations, add them to the data_items + sources = [source for source in sources if source.get("source", {}).get("name", "")] + + if len(sources) > 0: + events.append({"sources": sources}) + + return form_data, events + + +async def process_chat_response(response, events): + if not isinstance(response, StreamingResponse): + return response + + content_type = response.headers["Content-Type"] + is_openai = "text/event-stream" in content_type + is_ollama = 
"application/x-ndjson" in content_type + + if not is_openai and not is_ollama: + return response + + async def stream_wrapper(original_generator, events): + def wrap_item(item): + return f"data: {item}\n\n" if is_openai else f"{item}\n" + + for event in events: + yield wrap_item(json.dumps(event)) + + async for data in original_generator: + yield data + + return StreamingResponse( + stream_wrapper(response.body_iterator, events), + headers=dict(response.headers), + ) diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index a5af492ba..aba696f60 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -106,7 +106,7 @@ def openai_chat_message_template(model: str): def openai_chat_chunk_message_template( - model: str, message: Optional[str] = None + model: str, message: Optional[str] = None, usage: Optional[dict] = None ) -> dict: template = openai_chat_message_template(model) template["object"] = "chat.completion.chunk" @@ -114,17 +114,23 @@ def openai_chat_chunk_message_template( template["choices"][0]["delta"] = {"content": message} else: template["choices"][0]["finish_reason"] = "stop" + + if usage: + template["usage"] = usage return template def openai_chat_completion_message_template( - model: str, message: Optional[str] = None + model: str, message: Optional[str] = None, usage: Optional[dict] = None ) -> dict: template = openai_chat_message_template(model) template["object"] = "chat.completion" if message is not None: template["choices"][0]["message"] = {"content": message, "role": "assistant"} template["choices"][0]["finish_reason"] = "stop" + + if usage: + template["usage"] = usage return template diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py new file mode 100644 index 000000000..b9a4f07a3 --- /dev/null +++ b/backend/open_webui/utils/models.py @@ -0,0 +1,246 @@ +import time +import logging +import sys + +from aiocache import cached +from fastapi import Request + +from open_webui.routers import openai, ollama +from open_webui.functions import get_function_models + + +from open_webui.models.functions import Functions +from open_webui.models.models import Models + + +from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.access_control import has_access + + +from open_webui.config import ( + DEFAULT_ARENA_MODEL, +) + +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL + + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +async def get_all_base_models(request: Request): + function_models = [] + openai_models = [] + ollama_models = [] + + if request.app.state.config.ENABLE_OPENAI_API: + openai_models = await openai.get_all_models(request) + openai_models = openai_models["data"] + + if request.app.state.config.ENABLE_OLLAMA_API: + ollama_models = await ollama.get_all_models(request) + ollama_models = [ + { + "id": model["model"], + "name": model["name"], + "object": "model", + "created": int(time.time()), + "owned_by": "ollama", + "ollama": model, + } + for model in ollama_models["models"] + ] + + function_models = await get_function_models(request) + models = function_models + openai_models + ollama_models + + return models + + +@cached(ttl=3) +async def get_all_models(request): + models = await get_all_base_models(request) + + # If there are no models, return an empty list + if len(models) == 0: + return [] + + # Add arena models + if 
request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS: + arena_models = [] + if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0: + arena_models = [ + { + "id": model["id"], + "name": model["name"], + "info": { + "meta": model["meta"], + }, + "object": "model", + "created": int(time.time()), + "owned_by": "arena", + "arena": True, + } + for model in request.app.state.config.EVALUATION_ARENA_MODELS + ] + else: + # Add default arena model + arena_models = [ + { + "id": DEFAULT_ARENA_MODEL["id"], + "name": DEFAULT_ARENA_MODEL["name"], + "info": { + "meta": DEFAULT_ARENA_MODEL["meta"], + }, + "object": "model", + "created": int(time.time()), + "owned_by": "arena", + "arena": True, + } + ] + models = models + arena_models + + global_action_ids = [ + function.id for function in Functions.get_global_action_functions() + ] + enabled_action_ids = [ + function.id + for function in Functions.get_functions_by_type("action", active_only=True) + ] + + custom_models = Models.get_all_models() + for custom_model in custom_models: + if custom_model.base_model_id is None: + for model in models: + if ( + custom_model.id == model["id"] + or custom_model.id == model["id"].split(":")[0] + ): + if custom_model.is_active: + model["name"] = custom_model.name + model["info"] = custom_model.model_dump() + + action_ids = [] + if "info" in model and "meta" in model["info"]: + action_ids.extend( + model["info"]["meta"].get("actionIds", []) + ) + + model["action_ids"] = action_ids + else: + models.remove(model) + + elif custom_model.is_active and ( + custom_model.id not in [model["id"] for model in models] + ): + owned_by = "openai" + pipe = None + action_ids = [] + + for model in models: + if ( + custom_model.base_model_id == model["id"] + or custom_model.base_model_id == model["id"].split(":")[0] + ): + owned_by = model["owned_by"] + if "pipe" in model: + pipe = model["pipe"] + break + + if custom_model.meta: + meta = custom_model.meta.model_dump() + if "actionIds" in meta: + action_ids.extend(meta["actionIds"]) + + models.append( + { + "id": f"{custom_model.id}", + "name": custom_model.name, + "object": "model", + "created": custom_model.created_at, + "owned_by": owned_by, + "info": custom_model.model_dump(), + "preset": True, + **({"pipe": pipe} if pipe is not None else {}), + "action_ids": action_ids, + } + ) + + # Process action_ids to get the actions + def get_action_items_from_module(function, module): + actions = [] + if hasattr(module, "actions"): + actions = module.actions + return [ + { + "id": f"{function.id}.{action['id']}", + "name": action.get("name", f"{function.name} ({action['id']})"), + "description": function.meta.description, + "icon_url": action.get( + "icon_url", function.meta.manifest.get("icon_url", None) + ), + } + for action in actions + ] + else: + return [ + { + "id": function.id, + "name": function.name, + "description": function.meta.description, + "icon_url": function.meta.manifest.get("icon_url", None), + } + ] + + def get_function_module_by_id(function_id): + if function_id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[function_id] + else: + function_module, _, _ = load_function_module_by_id(function_id) + request.app.state.FUNCTIONS[function_id] = function_module + + for model in models: + action_ids = [ + action_id + for action_id in list(set(model.pop("action_ids", []) + global_action_ids)) + if action_id in enabled_action_ids + ] + + model["actions"] = [] + for action_id in action_ids: + action_function = 
Functions.get_function_by_id(action_id) + if action_function is None: + raise Exception(f"Action not found: {action_id}") + + function_module = get_function_module_by_id(action_id) + model["actions"].extend( + get_action_items_from_module(action_function, function_module) + ) + log.debug(f"get_all_models() returned {len(models)} models") + + request.app.state.MODELS = {model["id"]: model for model in models} + return models + + +def check_model_access(user, model): + if model.get("arena"): + if not has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + raise Exception("Model not found") + else: + model_info = Models.get_model_by_id(model.get("id")) + if not model_info: + raise Exception("Model not found") + elif not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise Exception("Model not found") diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 722b1ea73..f0ab7a345 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -12,8 +12,8 @@ from fastapi import ( ) from starlette.responses import RedirectResponse -from open_webui.apps.webui.models.auths import Auths -from open_webui.apps.webui.models.users import Users +from open_webui.models.auths import Auths +from open_webui.models.users import Users from open_webui.config import ( DEFAULT_USER_ROLE, ENABLE_OAUTH_SIGNUP, @@ -26,6 +26,7 @@ from open_webui.config import ( OAUTH_USERNAME_CLAIM, OAUTH_ALLOWED_ROLES, OAUTH_ADMIN_ROLES, + OAUTH_ALLOWED_DOMAINS, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig, @@ -33,7 +34,7 @@ from open_webui.config import ( from open_webui.constants import ERROR_MESSAGES from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE from open_webui.utils.misc import parse_duration -from open_webui.utils.utils import get_password_hash, create_token +from open_webui.utils.auth import get_password_hash, create_token from open_webui.utils.webhook import post_webhook log = logging.getLogger(__name__) @@ -49,6 +50,7 @@ auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES +auth_manager_config.OAUTH_ALLOWED_DOMAINS = OAUTH_ALLOWED_DOMAINS auth_manager_config.WEBHOOK_URL = WEBHOOK_URL auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN @@ -156,6 +158,14 @@ class OAuthManager: if not email: log.warning(f"OAuth callback failed, email is missing: {user_data}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + if ( + "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS + and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS + ): + log.warning( + f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}" + ) + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) # Check if the user exists user = Users.get_user_by_oauth_sub(provider_sub) @@ -253,9 +263,18 @@ class OAuthManager: secure=WEBUI_SESSION_COOKIE_SECURE, ) + if ENABLE_OAUTH_SIGNUP.value: + oauth_id_token = token.get("id_token") + response.set_cookie( + key="oauth_id_token", + value=oauth_id_token, + httponly=True, + samesite=WEBUI_SESSION_COOKIE_SAME_SITE, + secure=WEBUI_SESSION_COOKIE_SECURE, + ) # Redirect back to the frontend with the JWT token redirect_url = 
f"{request.base_url}auth#token={jwt_token}" - return RedirectResponse(url=redirect_url) + return RedirectResponse(url=redirect_url, headers=response.headers) oauth_manager = OAuthManager() diff --git a/backend/open_webui/utils/pdf_generator.py b/backend/open_webui/utils/pdf_generator.py index 6c3cf55ce..bbaf42dbb 100644 --- a/backend/open_webui/utils/pdf_generator.py +++ b/backend/open_webui/utils/pdf_generator.py @@ -9,7 +9,7 @@ import site from fpdf import FPDF from open_webui.env import STATIC_DIR, FONTS_DIR -from open_webui.apps.webui.models.chats import ChatTitleMessagesForm +from open_webui.models.chats import ChatTitleMessagesForm class PDFGenerator: @@ -51,21 +51,25 @@ class PDFGenerator: # extends pymdownx extension to convert markdown to html. # - https://facelessuser.github.io/pymdown-extensions/usage_notes/ - html_content = markdown(content, extensions=["pymdownx.extra"]) + # html_content = markdown(content, extensions=["pymdownx.extra"]) html_message = f""" -
            <div> {date_str} </div>
-            <div>
-                <h4>
-                    <strong>{role.title()}</strong>
-                    <small>{model}</small>
-                </h4>
-                <div>
-                    {html_content}
-                </div>
-            </div>
+        <div>
+            <div>
+                <h4>
+                    <strong>{role.title()}</strong>
+                    <small>{model}</small>
+                </h4>
+                <div> {date_str} </div>
+            </div>
+            <br/>
+            <br/>
+
+            <div>
+                {content}
+            </div>
+        </div>
""" return html_message @@ -74,18 +78,15 @@ class PDFGenerator: return f""" - - + -
            <body>
-                <div>
-                    <div>
-                        {self.form_data.title}
-                    </div>
-                    <div>
-                        {self.messages_html}
-                    </div>
-                </div>
-            </body>
+            <body>
+                <div>
+                    <div>
+                        {self.form_data.title}
+                    </div>
+                    {self.messages_html}
+                </div>
+            </body>
""" @@ -114,9 +115,12 @@ class PDFGenerator: pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf") pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf") pdf.add_font("NotoSansSC", "", f"{FONTS_DIR}/NotoSansSC-Regular.ttf") + pdf.add_font("Twemoji", "", f"{FONTS_DIR}/Twemoji.ttf") pdf.set_font("NotoSans", size=12) - pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP", "NotoSansSC"]) + pdf.set_fallback_fonts( + ["NotoSansKR", "NotoSansJP", "NotoSansSC", "Twemoji"] + ) pdf.set_auto_page_break(auto=True, margin=15) diff --git a/backend/open_webui/apps/webui/utils.py b/backend/open_webui/utils/plugin.py similarity index 89% rename from backend/open_webui/apps/webui/utils.py rename to backend/open_webui/utils/plugin.py index 51d379656..17b86cea1 100644 --- a/backend/open_webui/apps/webui/utils.py +++ b/backend/open_webui/utils/plugin.py @@ -5,9 +5,14 @@ import sys from importlib import util import types import tempfile +import logging -from open_webui.apps.webui.models.functions import Functions -from open_webui.apps.webui.models.tools import Tools +from open_webui.env import SRC_LOG_LEVELS +from open_webui.models.functions import Functions +from open_webui.models.tools import Tools + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) def extract_frontmatter(content): @@ -63,7 +68,7 @@ def replace_imports(content): return content -def load_toolkit_module_by_id(toolkit_id, content=None): +def load_tools_module_by_id(toolkit_id, content=None): if content is None: tool = Tools.get_tool_by_id(toolkit_id) @@ -95,7 +100,7 @@ def load_toolkit_module_by_id(toolkit_id, content=None): # Executing the modified content in the created module's namespace exec(content, module.__dict__) frontmatter = extract_frontmatter(content) - print(f"Loaded module: {module.__name__}") + log.info(f"Loaded module: {module.__name__}") # Create and return the object if the class 'Tools' is found in the module if hasattr(module, "Tools"): @@ -103,7 +108,7 @@ def load_toolkit_module_by_id(toolkit_id, content=None): else: raise Exception("No Tools class found in the module") except Exception as e: - print(f"Error loading module: {toolkit_id}: {e}") + log.error(f"Error loading module: {toolkit_id}: {e}") del sys.modules[module_name] # Clean up raise e finally: @@ -139,7 +144,7 @@ def load_function_module_by_id(function_id, content=None): # Execute the modified content in the created module's namespace exec(content, module.__dict__) frontmatter = extract_frontmatter(content) - print(f"Loaded module: {module.__name__}") + log.info(f"Loaded module: {module.__name__}") # Create appropriate object based on available class type in the module if hasattr(module, "Pipe"): @@ -151,7 +156,7 @@ def load_function_module_by_id(function_id, content=None): else: raise Exception("No Function class found in the module") except Exception as e: - print(f"Error loading module: {function_id}: {e}") + log.error(f"Error loading module: {function_id}: {e}") del sys.modules[module_name] # Cleanup by removing the module in case of error Functions.update_function_by_id(function_id, {"is_active": False}) @@ -164,7 +169,7 @@ def install_frontmatter_requirements(requirements): if requirements: req_list = [req.strip() for req in requirements.split(",")] for req in req_list: - print(f"Installing requirement: {req}") + log.info(f"Installing requirement: {req}") subprocess.check_call([sys.executable, "-m", "pip", "install", req]) else: - print("No requirements found in frontmatter.") + log.info("No requirements 
found in frontmatter.") diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py index b8501e92c..891016e43 100644 --- a/backend/open_webui/utils/response.py +++ b/backend/open_webui/utils/response.py @@ -21,8 +21,63 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response) message_content = data.get("message", {}).get("content", "") done = data.get("done", False) + usage = None + if done: + usage = { + "response_token/s": ( + round( + ( + ( + data.get("eval_count", 0) + / ((data.get("eval_duration", 0) / 1_000_000_000)) + ) + * 100 + ), + 2, + ) + if data.get("eval_duration", 0) > 0 + else "N/A" + ), + "prompt_token/s": ( + round( + ( + ( + data.get("prompt_eval_count", 0) + / ( + ( + data.get("prompt_eval_duration", 0) + / 1_000_000_000 + ) + ) + ) + * 100 + ), + 2, + ) + if data.get("prompt_eval_duration", 0) > 0 + else "N/A" + ), + "total_duration": round( + ((data.get("total_duration", 0) / 1_000_000) * 100), 2 + ), + "load_duration": round( + ((data.get("load_duration", 0) / 1_000_000) * 100), 2 + ), + "prompt_eval_count": data.get("prompt_eval_count", 0), + "prompt_eval_duration": round( + ((data.get("prompt_eval_duration", 0) / 1_000_000) * 100), 2 + ), + "eval_count": data.get("eval_count", 0), + "eval_duration": round( + ((data.get("eval_duration", 0) / 1_000_000) * 100), 2 + ), + "approximate_total": ( + lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s" + )((data.get("total_duration", 0) or 0) // 1_000_000_000), + } + data = openai_chat_chunk_message_template( - model, message_content if not done else None + model, message_content if not done else None, usage ) line = f"data: {json.dumps(data)}\n\n" diff --git a/backend/open_webui/utils/schemas.py b/backend/open_webui/utils/schemas.py deleted file mode 100644 index 4d1d448cd..000000000 --- a/backend/open_webui/utils/schemas.py +++ /dev/null @@ -1,112 +0,0 @@ -from ast import literal_eval -from typing import Any, Literal, Optional, Type - -from pydantic import BaseModel, Field, create_model - - -def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]: - """ - Converts a JSON schema to a Pydantic BaseModel class. - - Args: - json_schema: The JSON schema to convert. - - Returns: - A Pydantic BaseModel class. - """ - - # Extract the model name from the schema title. - model_name = tool_dict["name"] - schema = tool_dict["parameters"] - - # Extract the field definitions from the schema properties. - field_definitions = { - name: json_schema_to_pydantic_field(name, prop, schema.get("required", [])) - for name, prop in schema.get("properties", {}).items() - } - - # Create the BaseModel class using create_model(). - return create_model(model_name, **field_definitions) - - -def json_schema_to_pydantic_field( - name: str, json_schema: dict[str, Any], required: list[str] -) -> Any: - """ - Converts a JSON schema property to a Pydantic field definition. - - Args: - name: The field name. - json_schema: The JSON schema property. - - Returns: - A Pydantic field definition. - """ - - # Get the field type. - type_ = json_schema_to_pydantic_type(json_schema) - - # Get the field description. - description = json_schema.get("description") - - # Get the field examples. - examples = json_schema.get("examples") - - # Create a Field object with the type, description, and examples. - # The 'required' flag will be set later when creating the model. - return ( - type_, - Field( - description=description, - examples=examples, - default=... 
if name in required else None, - ), - ) - - -def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any: - """ - Converts a JSON schema type to a Pydantic type. - - Args: - json_schema: The JSON schema to convert. - - Returns: - A Pydantic type. - """ - - type_ = json_schema.get("type") - - if type_ == "string" or type_ == "str": - return str - elif type_ == "integer" or type_ == "int": - return int - elif type_ == "number" or type_ == "float": - return float - elif type_ == "boolean" or type_ == "bool": - return bool - elif type_ == "array" or type_ == "list": - items_schema = json_schema.get("items") - if items_schema: - item_type = json_schema_to_pydantic_type(items_schema) - return list[item_type] - else: - return list - elif type_ == "object": - # Handle nested models. - properties = json_schema.get("properties") - if properties: - nested_model = json_schema_to_model(json_schema) - return nested_model - else: - return dict - elif type_ == "null": - return Optional[Any] # Use Optional[Any] for nullable fields - elif type_ == "literal": - return Literal[literal_eval(json_schema.get("enum"))] - elif type_ == "optional": - inner_schema = json_schema.get("items", {"type": "string"}) - inner_type = json_schema_to_pydantic_type(inner_schema) - return Optional[inner_type] - else: - raise ValueError(f"Unsupported JSON schema type: {type_}") diff --git a/backend/open_webui/utils/security_headers.py b/backend/open_webui/utils/security_headers.py index a656b2935..fbcf7d697 100644 --- a/backend/open_webui/utils/security_headers.py +++ b/backend/open_webui/utils/security_headers.py @@ -20,12 +20,14 @@ def set_security_headers() -> Dict[str, str]: This function reads specific environment variables and uses their values to set corresponding security headers. The headers that can be set are: - cache-control + - permissions-policy - strict-transport-security - referrer-policy - x-content-type-options - x-download-options - x-frame-options - x-permitted-cross-domain-policies + - content-security-policy Each environment variable is associated with a specific setter function that constructs the header. 
If the environment variable is set, the @@ -38,11 +40,13 @@ def set_security_headers() -> Dict[str, str]: header_setters = { "CACHE_CONTROL": set_cache_control, "HSTS": set_hsts, + "PERMISSIONS_POLICY": set_permissions_policy, "REFERRER_POLICY": set_referrer, "XCONTENT_TYPE": set_xcontent_type, "XDOWNLOAD_OPTIONS": set_xdownload_options, "XFRAME_OPTIONS": set_xframe, "XPERMITTED_CROSS_DOMAIN_POLICIES": set_xpermitted_cross_domain_policies, + "CONTENT_SECURITY_POLICY": set_content_security_policy, } for env_var, setter in header_setters.items(): @@ -73,6 +77,15 @@ def set_xframe(value: str): return {"X-Frame-Options": value} +# Set Permissions-Policy response header +def set_permissions_policy(value: str): + pattern = r"^(?:(accelerometer|autoplay|camera|clipboard-read|clipboard-write|fullscreen|geolocation|gyroscope|magnetometer|microphone|midi|payment|picture-in-picture|sync-xhr|usb|xr-spatial-tracking)=\((self)?\),?)*$" + match = re.match(pattern, value, re.IGNORECASE) + if not match: + value = "none" + return {"Permissions-Policy": value} + + # Set Referrer-Policy response header def set_referrer(value: str): pattern = r"^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$" @@ -113,3 +126,8 @@ def set_xpermitted_cross_domain_policies(value: str): if not match: value = "none" return {"X-Permitted-Cross-Domain-Policies": value} + + +# Set Content-Security-Policy response header +def set_content_security_policy(value: str): + return {"Content-Security-Policy": value} diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 799cca11a..ebb7483ba 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -1,11 +1,36 @@ +import logging import math import re from datetime import datetime from typing import Optional +import uuid from open_webui.utils.misc import get_last_user_message, get_messages_content +from open_webui.env import SRC_LOG_LEVELS +from open_webui.config import DEFAULT_RAG_TEMPLATE + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def get_task_model_id( + default_model_id: str, task_model: str, task_model_external: str, models +) -> str: + # Set the task model + task_model_id = default_model_id + # Check if the user has a custom task model and use that model + if models[task_model_id]["owned_by"] == "ollama": + if task_model and task_model in models: + task_model_id = task_model + else: + if task_model_external and task_model_external in models: + task_model_id = task_model_external + + return task_model_id + def prompt_template( template: str, user_name: Optional[str] = None, user_location: Optional[str] = None @@ -16,12 +41,14 @@ def prompt_template( # Format the date to YYYY-MM-DD formatted_date = current_date.strftime("%Y-%m-%d") formatted_time = current_date.strftime("%I:%M:%S %p") + formatted_weekday = current_date.strftime("%A") template = template.replace("{{CURRENT_DATE}}", formatted_date) template = template.replace("{{CURRENT_TIME}}", formatted_time) template = template.replace( "{{CURRENT_DATETIME}}", f"{formatted_date} {formatted_time}" ) + template = template.replace("{{CURRENT_WEEKDAY}}", formatted_weekday) if user_name: # Replace {{USER_NAME}} in the template with the user's name @@ -42,7 +69,9 @@ def prompt_template( def replace_prompt_variable(template: str, prompt: str) -> str: def replacement_function(match): - full_match = match.group(0) + full_match = match.group( + 0 + ).lower() 
# Normalize to lowercase for consistent handling
         start_length = match.group(1)
         end_length = match.group(2)
         middle_length = match.group(3)
@@ -62,20 +91,23 @@ def replace_prompt_variable(template: str, prompt: str) -> str:
             return f"{start}...{end}"
         return ""
 
-    template = re.sub(
-        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
-        replacement_function,
-        template,
-    )
+    # Updated regex pattern to make it case-insensitive with the `(?i)` flag
+    pattern = r"(?i){{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}"
+    template = re.sub(pattern, replacement_function, template)
 
     return template
 
 
-def replace_messages_variable(template: str, messages: list[str]) -> str:
+def replace_messages_variable(
+    template: str, messages: Optional[list[str]] = None
+) -> str:
     def replacement_function(match):
         full_match = match.group(0)
         start_length = match.group(1)
         end_length = match.group(2)
         middle_length = match.group(3)
 
+        # If messages is None, handle it as an empty list
+        if messages is None:
+            return ""
+
         # Process messages based on the number of messages required
         if full_match == "{{MESSAGES}}":
 
@@ -110,6 +142,44 @@
 # {{prompt:middletruncate:8000}}
 
 
+def rag_template(template: str, context: str, query: str):
+    if template.strip() == "":
+        template = DEFAULT_RAG_TEMPLATE
+
+    if "[context]" not in template and "{{CONTEXT}}" not in template:
+        log.debug(
+            "WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
+        )
+
+    if "<context>" in context and "</context>" in context:
+        log.debug(
+            "WARNING: Potential prompt injection attack: the RAG "
+            "context contains '<context>' and '</context>'. This might be "
+            "nothing, or the user might be trying to hack something."
+ ) + + query_placeholders = [] + if "[query]" in context: + query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}" + template = template.replace("[query]", query_placeholder) + query_placeholders.append(query_placeholder) + + if "{{QUERY}}" in context: + query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}" + template = template.replace("{{QUERY}}", query_placeholder) + query_placeholders.append(query_placeholder) + + template = template.replace("[context]", context) + template = template.replace("{{CONTEXT}}", context) + template = template.replace("[query]", query) + template = template.replace("{{QUERY}}", query) + + for query_placeholder in query_placeholders: + template = template.replace(query_placeholder, query) + + return template + + def title_generation_template( template: str, messages: list[dict], user: Optional[dict] = None ) -> str: @@ -163,7 +233,29 @@ def emoji_generation_template( return template -def search_query_generation_template( +def autocomplete_generation_template( + template: str, + prompt: str, + messages: Optional[list[dict]] = None, + type: Optional[str] = None, + user: Optional[dict] = None, +) -> str: + template = template.replace("{{TYPE}}", type if type else "") + template = replace_prompt_variable(template, prompt) + template = replace_messages_variable(template, messages) + + template = prompt_template( + template, + **( + {"user_name": user.get("name"), "user_location": user.get("location")} + if user + else {} + ), + ) + return template + + +def query_generation_template( template: str, messages: list[dict], user: Optional[dict] = None ) -> str: prompt = get_last_user_message(messages) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 0b57eb35b..b6e13011d 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -1,11 +1,18 @@ import inspect import logging -from typing import Awaitable, Callable, get_type_hints +import re +from typing import Any, Awaitable, Callable, get_type_hints +from functools import update_wrapper, partial -from open_webui.apps.webui.models.tools import Tools -from open_webui.apps.webui.models.users import UserModel -from open_webui.apps.webui.utils import load_toolkit_module_by_id -from open_webui.utils.schemas import json_schema_to_model + +from fastapi import Request +from pydantic import BaseModel, Field, create_model +from langchain_core.utils.function_calling import convert_to_openai_function + + +from open_webui.models.tools import Tools +from open_webui.models.users import UserModel +from open_webui.utils.plugin import load_tools_module_by_id log = logging.getLogger(__name__) @@ -14,34 +21,34 @@ def apply_extra_params_to_tool_function( function: Callable, extra_params: dict ) -> Callable[..., Awaitable]: sig = inspect.signature(function) - extra_params = { - key: value for key, value in extra_params.items() if key in sig.parameters - } - is_coroutine = inspect.iscoroutinefunction(function) + extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters} + partial_func = partial(function, **extra_params) + if inspect.iscoroutinefunction(function): + update_wrapper(partial_func, function) + return partial_func - async def new_function(**kwargs): - extra_kwargs = kwargs | extra_params - if is_coroutine: - return await function(**extra_kwargs) - return function(**extra_kwargs) + async def new_function(*args, **kwargs): + return partial_func(*args, **kwargs) + update_wrapper(new_function, function) return new_function # Mutation on 
extra_params def get_tools( - webui_app, tool_ids: list[str], user: UserModel, extra_params: dict + request: Request, tool_ids: list[str], user: UserModel, extra_params: dict ) -> dict[str, dict]: - tools = {} + tools_dict = {} + for tool_id in tool_ids: - toolkit = Tools.get_tool_by_id(tool_id) - if toolkit is None: + tools = Tools.get_tool_by_id(tool_id) + if tools is None: continue - module = webui_app.state.TOOLS.get(tool_id, None) + module = request.app.state.TOOLS.get(tool_id, None) if module is None: - module, _ = load_toolkit_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = module + module, _ = load_tools_module_by_id(tool_id) + request.app.state.TOOLS[tool_id] = module extra_params["__id__"] = tool_id if hasattr(module, "valves") and hasattr(module, "Valves"): @@ -53,111 +60,143 @@ def get_tools( **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) ) - for spec in toolkit.specs: - # TODO: Fix hack for OpenAI API - for val in spec.get("parameters", {}).get("properties", {}).values(): - if val["type"] == "str": - val["type"] = "string" + for spec in tools.specs: + # Remove internal parameters + spec["parameters"]["properties"] = { + key: val + for key, val in spec["parameters"]["properties"].items() + if not key.startswith("__") + } + function_name = spec["name"] # convert to function that takes only model params and inserts custom params original_func = getattr(module, function_name) callable = apply_extra_params_to_tool_function(original_func, extra_params) - if hasattr(original_func, "__doc__"): - callable.__doc__ = original_func.__doc__ - # TODO: This needs to be a pydantic model tool_dict = { "toolkit_id": tool_id, "callable": callable, "spec": spec, - "pydantic_model": json_schema_to_model(spec), + "pydantic_model": function_to_pydantic_model(callable), "file_handler": hasattr(module, "file_handler") and module.file_handler, "citation": hasattr(module, "citation") and module.citation, } # TODO: if collision, prepend toolkit name - if function_name in tools: - log.warning(f"Tool {function_name} already exists in another toolkit!") - log.warning(f"Collision between {toolkit} and {tool_id}.") - log.warning(f"Discarding {toolkit}.{function_name}") + if function_name in tools_dict: + log.warning(f"Tool {function_name} already exists in another tools!") + log.warning(f"Collision between {tools} and {tool_id}.") + log.warning(f"Discarding {tools}.{function_name}") else: - tools[function_name] = tool_dict - return tools + tools_dict[function_name] = tool_dict + + return tools_dict -def doc_to_dict(docstring): - lines = docstring.split("\n") - description = lines[1].strip() - param_dict = {} +def parse_description(docstring: str | None) -> str: + """ + Parse a function's docstring to extract the description. + + Args: + docstring (str): The docstring to parse. + + Returns: + str: The description. 
+ """ + + if not docstring: + return "" + + lines = [line.strip() for line in docstring.strip().split("\n")] + description_lines: list[str] = [] for line in lines: - if ":param" in line: - line = line.replace(":param", "").strip() - param, desc = line.split(":", 1) - param_dict[param.strip()] = desc.strip() - ret_dict = {"description": description, "params": param_dict} - return ret_dict + if re.match(r":param", line) or re.match(r":return", line): + break + + description_lines.append(line) + + return "\n".join(description_lines) -def get_tools_specs(tools) -> list[dict]: - function_list = [ - {"name": func, "function": getattr(tools, func)} - for func in dir(tools) - if callable(getattr(tools, func)) +def parse_docstring(docstring): + """ + Parse a function's docstring to extract parameter descriptions in reST format. + + Args: + docstring (str): The docstring to parse. + + Returns: + dict: A dictionary where keys are parameter names and values are descriptions. + """ + if not docstring: + return {} + + # Regex to match `:param name: description` format + param_pattern = re.compile(r":param (\w+):\s*(.+)") + param_descriptions = {} + + for line in docstring.splitlines(): + match = param_pattern.match(line.strip()) + if not match: + continue + param_name, param_description = match.groups() + if param_name.startswith("__"): + continue + param_descriptions[param_name] = param_description + + return param_descriptions + + +def function_to_pydantic_model(func: Callable) -> type[BaseModel]: + """ + Converts a Python function's type hints and docstring to a Pydantic model, + including support for nested types, default values, and descriptions. + + Args: + func: The function whose type hints and docstring should be converted. + model_name: The name of the generated Pydantic model. + + Returns: + A Pydantic model class. + """ + type_hints = get_type_hints(func) + signature = inspect.signature(func) + parameters = signature.parameters + + docstring = func.__doc__ + descriptions = parse_docstring(docstring) + + tool_description = parse_description(docstring) + + field_defs = {} + for name, param in parameters.items(): + type_hint = type_hints.get(name, Any) + default_value = param.default if param.default is not param.empty else ... + description = descriptions.get(name, None) + if not description: + field_defs[name] = type_hint, default_value + continue + field_defs[name] = type_hint, Field(default_value, description=description) + + model = create_model(func.__name__, **field_defs) + model.__doc__ = tool_description + + return model + + +def get_callable_attributes(tool: object) -> list[Callable]: + return [ + getattr(tool, func) + for func in dir(tool) + if callable(getattr(tool, func)) and not func.startswith("__") - and not inspect.isclass(getattr(tools, func)) + and not inspect.isclass(getattr(tool, func)) ] - specs = [] - for function_item in function_list: - function_name = function_item["name"] - function = function_item["function"] - function_doc = doc_to_dict(function.__doc__ or function_name) - specs.append( - { - "name": function_name, - # TODO: multi-line desc? 
- "description": function_doc.get("description", function_name), - "parameters": { - "type": "object", - "properties": { - param_name: { - "type": param_annotation.__name__.lower(), - **( - { - "enum": ( - str(param_annotation.__args__) - if hasattr(param_annotation, "__args__") - else None - ) - } - if hasattr(param_annotation, "__args__") - else {} - ), - "description": function_doc.get("params", {}).get( - param_name, param_name - ), - } - for param_name, param_annotation in get_type_hints( - function - ).items() - if param_name != "return" - and not ( - param_name.startswith("__") and param_name.endswith("__") - ) - }, - "required": [ - name - for name, param in inspect.signature( - function - ).parameters.items() - if param.default is param.empty - and not (name.startswith("__") and name.endswith("__")) - ], - }, - } - ) - - return specs +def get_tools_specs(tool_class: object) -> list[dict]: + function_list = get_callable_attributes(tool_class) + models = map(function_to_pydantic_model, function_list) + return [convert_to_openai_function(tool) for tool in models] diff --git a/backend/requirements.txt b/backend/requirements.txt index 561f291cb..79e898c6a 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,7 +1,7 @@ fastapi==0.111.0 uvicorn[standard]==0.30.6 pydantic==2.9.2 -python-multipart==0.0.9 +python-multipart==0.0.18 Flask==3.0.3 Flask-Cors==5.0.0 @@ -11,20 +11,23 @@ python-jose==3.3.0 passlib[bcrypt]==1.7.4 requests==2.32.3 -aiohttp==3.10.8 +aiohttp==3.11.8 async-timeout +aiocache +aiofiles sqlalchemy==2.0.32 -alembic==1.13.2 +alembic==1.14.0 peewee==3.17.6 peewee-migrate==1.12.2 psycopg2-binary==2.9.9 +pgvector==0.3.5 PyMySQL==1.1.1 bcrypt==4.2.0 pymongo redis -boto3==1.35.0 +boto3==1.35.53 argon2-cffi==23.1.0 APScheduler==3.10.4 @@ -35,23 +38,24 @@ anthropic google-generativeai==0.7.2 tiktoken -langchain==0.2.15 -langchain-community==0.2.12 +langchain==0.3.7 +langchain-community==0.3.7 langchain-chroma==0.1.4 fake-useragent==1.5.1 -chromadb==0.5.9 -pymilvus==2.4.7 +chromadb==0.5.15 +pymilvus==2.5.0 qdrant-client~=1.12.0 +opensearch-py==2.7.1 -sentence-transformers==3.2.0 +sentence-transformers==3.3.1 colbert-ai==0.2.21 einops==0.8.0 ftfy==6.2.3 pypdf==4.3.1 -xhtml2pdf==0.2.16 +fpdf2==2.7.9 pymdown-extensions==10.11.2 docx2txt==0.8 python-pptx==1.0.0 @@ -65,11 +69,11 @@ pyxlsb==1.0.10 xlrd==2.0.1 validators==0.33.0 psutil +sentencepiece +soundfile==0.12.1 opencv-python-headless==4.10.0.84 rapidocr-onnxruntime==1.3.24 - -fpdf2==2.7.9 rank-bm25==0.2.2 faster-whisper==1.0.3 @@ -79,12 +83,12 @@ authlib==1.3.2 black==24.8.0 langfuse==2.44.0 -youtube-transcript-api==0.6.2 +youtube-transcript-api==0.6.3 pytube==15.0.0 extract_msg pydub -duckduckgo-search~=6.2.13 +duckduckgo-search~=6.3.5 ## Tests docker~=7.1.0 @@ -92,3 +96,6 @@ pytest~=8.3.2 pytest-docker~=3.1.1 googleapis-common-protos==1.63.2 + +## LDAP +ldap3==2.9.1 diff --git a/package-lock.json b/package-lock.json index 6cf21ae1a..16542ed99 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,24 +1,33 @@ { "name": "open-webui", - "version": "0.3.35", + "version": "0.4.8", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.3.35", + "version": "0.4.8", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", "@codemirror/language-data": "^6.5.1", "@codemirror/theme-one-dark": "^6.1.2", "@huggingface/transformers": "^3.0.0", + "@mediapipe/tasks-vision": "^0.10.17", "@pyscript/core": "^0.4.32", 
"@sveltejs/adapter-node": "^2.0.0", + "@tiptap/core": "^2.10.0", + "@tiptap/extension-code-block-lowlight": "^2.10.0", + "@tiptap/extension-highlight": "^2.10.0", + "@tiptap/extension-placeholder": "^2.10.0", + "@tiptap/extension-typography": "^2.10.0", + "@tiptap/pm": "^2.10.0", + "@tiptap/starter-kit": "^2.10.0", "@xyflow/svelte": "^0.1.19", "async": "^3.2.5", "bits-ui": "^0.19.7", "codemirror": "^6.0.1", + "codemirror-lang-hcl": "^0.0.0-beta.2", "crc-32": "^1.2.2", "dayjs": "^1.11.10", "dompurify": "^3.1.6", @@ -72,6 +81,7 @@ "postcss": "^8.4.31", "prettier": "^3.3.3", "prettier-plugin-svelte": "^3.2.6", + "sass-embedded": "^1.81.0", "svelte": "^4.2.18", "svelte-check": "^3.8.5", "svelte-confetti": "^1.3.2", @@ -135,6 +145,13 @@ "resolved": "https://registry.npmjs.org/@braintree/sanitize-url/-/sanitize-url-6.0.4.tgz", "integrity": "sha512-s3jaWicZd0pkP0jf5ysyHUI/RE7MHos6qlToFcGWXVp+ykHOy77OUMrfbgJ9it2C5bow7OIQwYYaHjk9XlBQ2A==" }, + "node_modules/@bufbuild/protobuf": { + "version": "2.2.2", + "resolved": "https://registry.npmjs.org/@bufbuild/protobuf/-/protobuf-2.2.2.tgz", + "integrity": "sha512-UNtPCbrwrenpmrXuRwn9jYpPoweNXj8X5sMvYgsqYyaH8jQ6LfUJSk3dJLnBK+6sfYPrF4iAIo5sd5HQ+tg75A==", + "devOptional": true, + "license": "(Apache-2.0 AND BSD-3-Clause)" + }, "node_modules/@codemirror/autocomplete": { "version": "6.16.2", "resolved": "https://registry.npmjs.org/@codemirror/autocomplete/-/autocomplete-6.16.2.tgz", @@ -1749,6 +1766,11 @@ "@lezer/lr": "^1.4.0" } }, + "node_modules/@mediapipe/tasks-vision": { + "version": "0.10.17", + "resolved": "https://registry.npmjs.org/@mediapipe/tasks-vision/-/tasks-vision-0.10.17.tgz", + "integrity": "sha512-CZWV/q6TTe8ta61cZXjfnnHsfWIdFhms03M9T7Cnd5y2mdpylJM0rF1qRq+wsQVRMLz1OYPVEBU9ph2Bx8cxrg==" + }, "node_modules/@melt-ui/svelte": { "version": "0.76.0", "resolved": "https://registry.npmjs.org/@melt-ui/svelte/-/svelte-0.76.0.tgz", @@ -1815,9 +1837,10 @@ } }, "node_modules/@polka/url": { - "version": "1.0.0-next.25", - "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.25.tgz", - "integrity": "sha512-j7P6Rgr3mmtdkeDGTe0E/aYyWEWVtc5yFXtHCRHs28/jptDEWfaVOc5T7cblqy1XKPPfCxJc/8DwQ5YgLOZOVQ==" + "version": "1.0.0-next.28", + "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.28.tgz", + "integrity": "sha512-8LduaNlMZGwdZ6qWrKlfa+2M4gahzFkprZiAt2TF8uS0qQgBizKXpXURqvTJ4WtmupWxaLqjRb2UCTe72mu+Aw==", + "license": "MIT" }, "node_modules/@popperjs/core": { "version": "2.11.8", @@ -1895,6 +1918,12 @@ "type-checked-collections": "^0.1.7" } }, + "node_modules/@remirror/core-constants": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/@remirror/core-constants/-/core-constants-3.0.0.tgz", + "integrity": "sha512-42aWfPrimMfDKDi4YegyS7x+/0tlzaqwPQCULLanv3DMIlu96KTJR0fM5isWX2UViOqlGnX6YFgqWepcX+XMNg==", + "license": "MIT" + }, "node_modules/@rollup/plugin-commonjs": { "version": "25.0.7", "resolved": "https://registry.npmjs.org/@rollup/plugin-commonjs/-/plugin-commonjs-25.0.7.tgz", @@ -2221,31 +2250,33 @@ } }, "node_modules/@sveltejs/adapter-static": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/@sveltejs/adapter-static/-/adapter-static-3.0.2.tgz", - "integrity": "sha512-/EBFydZDwfwFfFEuF1vzUseBoRziwKP7AoHAwv+Ot3M084sE/HTVBHf9mCmXfdM9ijprY5YEugZjleflncX5fQ==", + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/@sveltejs/adapter-static/-/adapter-static-3.0.6.tgz", + "integrity": "sha512-MGJcesnJWj7FxDcB/GbrdYD3q24Uk0PIL4QIX149ku+hlJuj//nxUbb0HxUTpjkecWfHjVveSUnUaQWnPRXlpg==", "dev": true, + 
"license": "MIT", "peerDependencies": { "@sveltejs/kit": "^2.0.0" } }, "node_modules/@sveltejs/kit": { - "version": "2.6.2", - "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.6.2.tgz", - "integrity": "sha512-ruogrSPXjckn5poUiZU8VYNCSPHq66SFR1AATvOikQxtP6LNI4niAZVX/AWZRe/EPDG3oY2DNJ9c5z7u0t2NAQ==", + "version": "2.9.0", + "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.9.0.tgz", + "integrity": "sha512-W3E7ed3ChB6kPqRs2H7tcHp+Z7oiTFC6m+lLyAQQuyXeqw6LdNuuwEUla+5VM0OGgqQD+cYD6+7Xq80vVm17Vg==", "hasInstallScript": true, + "license": "MIT", "dependencies": { "@types/cookie": "^0.6.0", - "cookie": "^0.7.0", + "cookie": "^0.6.0", "devalue": "^5.1.0", - "esm-env": "^1.0.0", + "esm-env": "^1.2.1", "import-meta-resolve": "^4.1.0", "kleur": "^4.1.5", "magic-string": "^0.30.5", "mrmime": "^2.0.0", "sade": "^1.8.1", "set-cookie-parser": "^2.6.0", - "sirv": "^2.0.4", + "sirv": "^3.0.0", "tiny-glob": "^0.2.9" }, "bin": { @@ -2255,9 +2286,9 @@ "node": ">=18.13" }, "peerDependencies": { - "@sveltejs/vite-plugin-svelte": "^3.0.0 || ^4.0.0-next.1", + "@sveltejs/vite-plugin-svelte": "^3.0.0 || ^4.0.0-next.1 || ^5.0.0", "svelte": "^4.0.0 || ^5.0.0-next.0", - "vite": "^5.0.3" + "vite": "^5.0.3 || ^6.0.0" } }, "node_modules/@sveltejs/vite-plugin-svelte": { @@ -2320,6 +2351,391 @@ "tailwindcss": ">=3.0.0 || insiders" } }, + "node_modules/@tiptap/core": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/core/-/core-2.10.0.tgz", + "integrity": "sha512-58nAjPxLRFcXepdDqQRC1mhrw6E8Sanqr6bbO4Tz0+FWgDJMZvHG+dOK5wHaDVNSgK2iJDz08ETvQayfOOgDvg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/pm": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-blockquote": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-blockquote/-/extension-blockquote-2.10.0.tgz", + "integrity": "sha512-6Xmfo2lpfIRcbfkLD/NGX4YgQqfgAbu6XaZQZf5oGtHLPTrz4D7Mw20GgNBHzae2XwUCwLMt6zXOkBgU/LnlZg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-bold": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-bold/-/extension-bold-2.10.0.tgz", + "integrity": "sha512-1wL8UI1Aii0u2cbDEvwyqsZb2pgBt8HLJdsIax/ELoF2tKCD5821nElqTGLBBg4pUGPa0ru9ZemuL8GdXZp3Qg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-bullet-list": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-bullet-list/-/extension-bullet-list-2.10.0.tgz", + "integrity": "sha512-Cl+DGu6D3SgF/hlKUDNet3gaZFy6cPEonOOkHwzXoybDXXdddFbaTvt9MLkBRUR3ldksXuVRP2/LwZsK5WyxJQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-code": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-code/-/extension-code-2.10.0.tgz", + "integrity": "sha512-8JznKG1Jmv8gJezZGPoka8oRmfrcAAnMEOeMpKXjwMrIbQ6QynTZpqMGGVL1kfkZlLV84PYm+CGjGgjSsT4iZw==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + 
"@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-code-block": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-code-block/-/extension-code-block-2.10.0.tgz", + "integrity": "sha512-QH+LP7L1s1EJlrDFnfgOP0q+Siqt0Zbkx4ICMcUGvEsycl53Ti8P0DRW7fAjRISdTCItuWJYvtmiYY7O3rYb+Q==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0", + "@tiptap/pm": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-code-block-lowlight": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-code-block-lowlight/-/extension-code-block-lowlight-2.10.0.tgz", + "integrity": "sha512-dAv03XIHT5h+sdFmJzvx2FfpfFOOK9SBKHflRUdqTa8eA+0VZNAcPRjvJWVEWqts1fKZDJj774mO28NlhFzk9Q==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0", + "@tiptap/extension-code-block": "^2.7.0", + "@tiptap/pm": "^2.7.0", + "highlight.js": "^11", + "lowlight": "^2 || ^3" + } + }, + "node_modules/@tiptap/extension-document": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-document/-/extension-document-2.10.0.tgz", + "integrity": "sha512-vseMW3EKiQAPgdbN48Y8F0nRqWhhrAo9DLacAfP7tu0x3uv44uotNjDBtAgp5QmJmqQVyrEdkLSZaU5vFzduhQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-dropcursor": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-dropcursor/-/extension-dropcursor-2.10.0.tgz", + "integrity": "sha512-tifxp/a3NxTjLAuYBx9XAwVo4MSDoY/mQ8E18QtuXj0vuieCFxd8Bkyre0otubIAAQePXLTVGQoxPrKmMAa+Jg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0", + "@tiptap/pm": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-gapcursor": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-gapcursor/-/extension-gapcursor-2.10.0.tgz", + "integrity": "sha512-GViEnSnEBE74k7SYdXrQ4aXlKmWkrd9awdj/TgDSORgpZ4Dfyqtn+ENIWWby4NhL+BPM9P5hGCjkQXZsi6JKOw==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0", + "@tiptap/pm": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-hard-break": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-hard-break/-/extension-hard-break-2.10.0.tgz", + "integrity": "sha512-NL/xPYUhhvQyCnOO5Yn+BlBOMLC1ru32nw7ox12TShGmaeKBrnV0DhzBRkyJU0MqCS26oWjieNPxfu0lR3oMSA==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-heading": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-heading/-/extension-heading-2.10.0.tgz", + "integrity": "sha512-x2Uj5wrAHFaUdlChwLoQVmWtzZCuNyJpBRA19kA4idWL5z+6cIrUWepvwVBxA8ou6ictbzWW15o+blKtW7DlqA==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-highlight": { + "version": "2.10.0", 
+ "resolved": "https://registry.npmjs.org/@tiptap/extension-highlight/-/extension-highlight-2.10.0.tgz", + "integrity": "sha512-HU8UuKU7ljlzNn7jg29pM8QtIX7QvePcBjcWAt6K3qVwF1cbBNguIjKRY2rmoonU2nu8I6GknQNgV847kZifCQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-history": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-history/-/extension-history-2.10.0.tgz", + "integrity": "sha512-5aYOmxqaCnw7e7wmWqFZmkpYCxxDjEzFbgVI6WknqNwqeOizR4+YJf3aAt/lTbksLJe47XF+NBX51gOm/ZBCiw==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0", + "@tiptap/pm": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-horizontal-rule": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-horizontal-rule/-/extension-horizontal-rule-2.10.0.tgz", + "integrity": "sha512-el1SzI/x/h4HW8UltxJlyMSrRsO55ypKPLQHJC9h7F6kTTR31fJUzQa3AeTFrZvXS0kNHIFRpAMstw+N0L5TYg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0", + "@tiptap/pm": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-italic": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-italic/-/extension-italic-2.10.0.tgz", + "integrity": "sha512-MqPYbHAEeO8QBvZRIkF4J2OTf/uiUPzUiXGLJ50w1ozfMBIw1txMvfR3g2cpwfvZlcOgYTgy7M0Oq00nQz5eXg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-list-item": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-list-item/-/extension-list-item-2.10.0.tgz", + "integrity": "sha512-BxC6NNHd2xcC+mk5hpYWURUdj/mRz6TGFwH5CsyrUXPxApx0+V+EPHaAgdpu8dr+jtTEzjXF62V6e2JmOAPimg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-ordered-list": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-ordered-list/-/extension-ordered-list-2.10.0.tgz", + "integrity": "sha512-jsK+mvzs7HmxQuQOU3HgIga+v7zUbQlmSP4/danusqUihJ+lc1n0frDCIkVvJrnSB3FChvNgT6ZEA14HOhdJzg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-paragraph": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-paragraph/-/extension-paragraph-2.10.0.tgz", + "integrity": "sha512-4LUkVaJYjNdNZ7QOX6TRcA+m7oCtyrLGk49G22wl7XcPBkQPILP1mCUCU4f41bhjfhCgK5PPWP63kMtD+cEACg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-placeholder": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-placeholder/-/extension-placeholder-2.10.0.tgz", + "integrity": "sha512-1o6azk2plgYAFgMrV3prnBb1NZjl2V1T3wwnH4n3/h9z9lJ0v5BBAk9r+TRYSrcdXknwwHAWFYnQe6dc9buG2g==", + "license": "MIT", + "funding": { + "type": 
"github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0", + "@tiptap/pm": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-strike": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-strike/-/extension-strike-2.10.0.tgz", + "integrity": "sha512-SxApLJMQkxnmPGR3lwaskvLK61yI+Bu9hGZGdwMZqNh6o3LoDOxDaXjHD5joeMYQiqQrBE9zg46506MsXtrU7Q==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-text": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-text/-/extension-text-2.10.0.tgz", + "integrity": "sha512-SSnNncADS1KucdEcJlF6WGCs5+1pAhPrD68vlw34oj3NDT3Zh05KiyXsCV3Nw4wpHOnbWahV+z3uT2SnR+xgoQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-text-style": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-text-style/-/extension-text-style-2.10.0.tgz", + "integrity": "sha512-VZtH1dp64wg1UcFtUPpRQK+kOm4JHBIv+WXuKX7EnpIEKjHKnyfV94BBVmaqY5UE4n3kbkkmIRB2Cmix/10AMg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/extension-typography": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/extension-typography/-/extension-typography-2.10.0.tgz", + "integrity": "sha512-03IOfJm4bk2hZ4SsSfxgBOVzcDxMRBlFD7ZY12H2EGNf1TKxj/0ANWhAH54FtquuOMoY5aWg5LZf0lk++8UDAw==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + }, + "peerDependencies": { + "@tiptap/core": "^2.7.0" + } + }, + "node_modules/@tiptap/pm": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/pm/-/pm-2.10.0.tgz", + "integrity": "sha512-ohshlWf4MlW6D3rQkNQnhmiQ2w4pwRoQcJmTPt8UJoIDGkeKmZh494fQp4Aeh80XuGd81SsCv//1HJeyaeHJYQ==", + "license": "MIT", + "dependencies": { + "prosemirror-changeset": "^2.2.1", + "prosemirror-collab": "^1.3.1", + "prosemirror-commands": "^1.6.2", + "prosemirror-dropcursor": "^1.8.1", + "prosemirror-gapcursor": "^1.3.2", + "prosemirror-history": "^1.4.1", + "prosemirror-inputrules": "^1.4.0", + "prosemirror-keymap": "^1.2.2", + "prosemirror-markdown": "^1.13.1", + "prosemirror-menu": "^1.2.4", + "prosemirror-model": "^1.23.0", + "prosemirror-schema-basic": "^1.2.3", + "prosemirror-schema-list": "^1.4.1", + "prosemirror-state": "^1.4.3", + "prosemirror-tables": "^1.6.1", + "prosemirror-trailing-node": "^3.0.0", + "prosemirror-transform": "^1.10.2", + "prosemirror-view": "^1.36.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + } + }, + "node_modules/@tiptap/starter-kit": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/@tiptap/starter-kit/-/starter-kit-2.10.0.tgz", + "integrity": "sha512-hMIM9a6HjYZo25EzhZHlKEIR7CFi0grRSOltEyggiyBuQqKFkI7iwCpZVVtviDV1FwV0EPANpIAxPS7aBRgFdg==", + "license": "MIT", + "dependencies": { + "@tiptap/core": "^2.10.0", + "@tiptap/extension-blockquote": "^2.10.0", + "@tiptap/extension-bold": "^2.10.0", + "@tiptap/extension-bullet-list": "^2.10.0", + "@tiptap/extension-code": "^2.10.0", + "@tiptap/extension-code-block": "^2.10.0", + 
"@tiptap/extension-document": "^2.10.0", + "@tiptap/extension-dropcursor": "^2.10.0", + "@tiptap/extension-gapcursor": "^2.10.0", + "@tiptap/extension-hard-break": "^2.10.0", + "@tiptap/extension-heading": "^2.10.0", + "@tiptap/extension-history": "^2.10.0", + "@tiptap/extension-horizontal-rule": "^2.10.0", + "@tiptap/extension-italic": "^2.10.0", + "@tiptap/extension-list-item": "^2.10.0", + "@tiptap/extension-ordered-list": "^2.10.0", + "@tiptap/extension-paragraph": "^2.10.0", + "@tiptap/extension-strike": "^2.10.0", + "@tiptap/extension-text": "^2.10.0", + "@tiptap/extension-text-style": "^2.10.0", + "@tiptap/pm": "^2.10.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/ueberdosis" + } + }, "node_modules/@types/cookie": { "version": "0.6.0", "resolved": "https://registry.npmjs.org/@types/cookie/-/cookie-0.6.0.tgz", @@ -2399,6 +2815,16 @@ "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz", "integrity": "sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==" }, + "node_modules/@types/hast": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@types/hast/-/hast-3.0.4.tgz", + "integrity": "sha512-WPs+bbQw5aCj+x6laNGWLH3wviHtoCv/P3+otBhbOhJgG8qtpdAMlTCxLtsTWA7LH1Oh/bFCHsBn0TPS5m30EQ==", + "license": "MIT", + "peer": true, + "dependencies": { + "@types/unist": "*" + } + }, "node_modules/@types/json-schema": { "version": "7.0.15", "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", @@ -3419,6 +3845,13 @@ "ieee754": "^1.2.1" } }, + "node_modules/buffer-builder": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/buffer-builder/-/buffer-builder-0.2.0.tgz", + "integrity": "sha512-7VPMEPuYznPSoR21NE1zvd2Xna6c/CloiZCfcMXR1Jny6PjX0N4Nsa38zcBFo/FMK+BlA+FLKbJCQ0i2yxp+Xg==", + "devOptional": true, + "license": "MIT/X11" + }, "node_modules/buffer-crc32": { "version": "0.2.13", "resolved": "https://registry.npmjs.org/buffer-crc32/-/buffer-crc32-0.2.13.tgz", @@ -3835,6 +4268,17 @@ "@codemirror/view": "^6.0.0" } }, + "node_modules/codemirror-lang-hcl": { + "version": "0.0.0-beta.2", + "resolved": "https://registry.npmjs.org/codemirror-lang-hcl/-/codemirror-lang-hcl-0.0.0-beta.2.tgz", + "integrity": "sha512-R3ew7Z2EYTdHTMXsWKBW9zxnLoLPYO+CrAa3dPZjXLrIR96Q3GR4cwJKF7zkSsujsnWgwRQZonyWpXYXfhQYuQ==", + "license": "MIT", + "dependencies": { + "@codemirror/language": "^6.0.0", + "@lezer/highlight": "^1.0.0", + "@lezer/lr": "^1.0.0" + } + }, "node_modules/coincident": { "version": "1.2.3", "resolved": "https://registry.npmjs.org/coincident/-/coincident-1.2.3.tgz", @@ -3892,6 +4336,13 @@ "integrity": "sha512-IfEDxwoWIjkeXL1eXcDiow4UbKjhLdq6/EuSVR9GMN7KVH3r9gQ83e73hsz1Nd1T3ijd5xv1wcWRYO+D6kCI2w==", "dev": true }, + "node_modules/colorjs.io": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/colorjs.io/-/colorjs.io-0.5.2.tgz", + "integrity": "sha512-twmVoizEW7ylZSN32OgKdXRmo1qg+wT5/6C3xu5b9QsWzSFAhHLn2xd8ro0diCsKfCj1RdaTP/nrcW+vAoQPIw==", + "devOptional": true, + "license": "MIT" + }, "node_modules/colors": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/colors/-/colors-1.4.0.tgz", @@ -3955,9 +4406,10 @@ "dev": true }, "node_modules/cookie": { - "version": "0.7.1", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.1.tgz", - "integrity": "sha512-6DnInpx7SJ2AK3+CTUE/ZM0vWTUboZCegxhC2xiIydHR9jNuTAASBrfEpHhiGOZw/nX51bHt6YQl8jsGo4y/0w==", + "version": "0.6.0", + "resolved": 
"https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", + "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -3993,9 +4445,10 @@ "integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==" }, "node_modules/cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "license": "MIT", "dependencies": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", @@ -4753,6 +5206,20 @@ "resolved": "https://registry.npmjs.org/devalue/-/devalue-5.1.1.tgz", "integrity": "sha512-maua5KUiapvEwiEAe+XnlZ3Rh0GD+qI1J/nb9vrJc3muPXvcF/8gXYTWF76+5DAqHyDUtOIImEuo0YKE9mshVw==" }, + "node_modules/devlop": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/devlop/-/devlop-1.1.0.tgz", + "integrity": "sha512-RWmIqhcFf1lRYBvNmr7qTNuyCt/7/ns2jbpp1+PalgE/rDQcBT0fioSMUpJ93irlUhC5hrg4cYqe6U+0ImW0rA==", + "license": "MIT", + "peer": true, + "dependencies": { + "dequal": "^2.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/didyoumean": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz", @@ -5040,7 +5507,6 @@ "version": "4.0.0", "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", "integrity": "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==", - "dev": true, "engines": { "node": ">=10" }, @@ -5240,9 +5706,10 @@ } }, "node_modules/esm-env": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/esm-env/-/esm-env-1.0.0.tgz", - "integrity": "sha512-Cf6VksWPsTuW01vU9Mk/3vRue91Zevka5SjyNf3nEpokFRuqt/KjUQoGAwq9qMmhpLTHmXzSIrFRw8zxWzmFBA==" + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/esm-env/-/esm-env-1.2.1.tgz", + "integrity": "sha512-U9JedYYjCnadUlXk7e1Kr+aENQhtUaoaV9+gZm1T8LC/YBAPJx3NSPIAurFOC0U5vrdSevnUJS2/wUVxGwPhng==", + "license": "MIT" }, "node_modules/espree": { "version": "9.6.1", @@ -5971,7 +6438,7 @@ "version": "4.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, + "devOptional": true, "engines": { "node": ">=8" } @@ -6234,6 +6701,13 @@ "node": ">= 4" } }, + "node_modules/immutable": { + "version": "5.0.3", + "resolved": "https://registry.npmjs.org/immutable/-/immutable-5.0.3.tgz", + "integrity": "sha512-P8IdPQHq3lA1xVeBRi5VPqUm5HDgKnx0Ru51wZz5mjxHr5n3RWhjIpOFU7ybkUxfB+5IToy+OLaHYDBIWsv+uw==", + "devOptional": true, + "license": "MIT" + }, "node_modules/import-fresh": { "version": "3.3.0", "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.0.tgz", @@ -6975,6 +7449,22 @@ "get-func-name": "^2.0.1" } }, + "node_modules/lowlight": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/lowlight/-/lowlight-3.1.0.tgz", + "integrity": "sha512-CEbNVoSikAxwDMDPjXlqlFYiZLkDJHwyGu/MfOsJnF3d7f3tds5J3z8s/l9TMXhzfsJCCJEAsD78842mwmg0PQ==", + "license": "MIT", + 
"peer": true, + "dependencies": { + "@types/hast": "^3.0.0", + "devlop": "^1.0.0", + "highlight.js": "~11.9.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/magic-string": { "version": "0.30.11", "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.11.tgz", @@ -7755,6 +8245,7 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/mrmime/-/mrmime-2.0.0.tgz", "integrity": "sha512-eu38+hdgojoyq63s+yTpN4XMBdt5l8HhMhc4VKLO9KM5caLIBvUm4thi7fFaxyTmCKeNnXZ5pAlBwCUnhA09uw==", + "license": "MIT", "engines": { "node": ">=10" } @@ -8597,14 +9088,33 @@ "node": "10.* || >= 12.*" } }, + "node_modules/prosemirror-changeset": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/prosemirror-changeset/-/prosemirror-changeset-2.2.1.tgz", + "integrity": "sha512-J7msc6wbxB4ekDFj+n9gTW/jav/p53kdlivvuppHsrZXCaQdVgRghoZbSS3kwrRyAstRVQ4/+u5k7YfLgkkQvQ==", + "license": "MIT", + "dependencies": { + "prosemirror-transform": "^1.0.0" + } + }, + "node_modules/prosemirror-collab": { + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/prosemirror-collab/-/prosemirror-collab-1.3.1.tgz", + "integrity": "sha512-4SnynYR9TTYaQVXd/ieUvsVV4PDMBzrq2xPUWutHivDuOshZXqQ5rGbZM84HEaXKbLdItse7weMGOUdDVcLKEQ==", + "license": "MIT", + "dependencies": { + "prosemirror-state": "^1.0.0" + } + }, "node_modules/prosemirror-commands": { - "version": "1.6.0", - "resolved": "https://registry.npmjs.org/prosemirror-commands/-/prosemirror-commands-1.6.0.tgz", - "integrity": "sha512-xn1U/g36OqXn2tn5nGmvnnimAj/g1pUx2ypJJIe8WkVX83WyJVC5LTARaxZa2AtQRwntu9Jc5zXs9gL9svp/mg==", + "version": "1.6.2", + "resolved": "https://registry.npmjs.org/prosemirror-commands/-/prosemirror-commands-1.6.2.tgz", + "integrity": "sha512-0nDHH++qcf/BuPLYvmqZTUUsPJUCPBUXt0J1ErTcDIS369CTp773itzLGIgIXG4LJXOlwYCr44+Mh4ii6MP1QA==", + "license": "MIT", "dependencies": { "prosemirror-model": "^1.0.0", "prosemirror-state": "^1.0.0", - "prosemirror-transform": "^1.0.0" + "prosemirror-transform": "^1.10.2" } }, "node_modules/prosemirror-dropcursor": { @@ -8730,18 +9240,48 @@ "prosemirror-view": "^1.27.0" } }, + "node_modules/prosemirror-tables": { + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/prosemirror-tables/-/prosemirror-tables-1.6.1.tgz", + "integrity": "sha512-p8WRJNA96jaNQjhJolmbxTzd6M4huRE5xQ8OxjvMhQUP0Nzpo4zz6TztEiwk6aoqGBhz9lxRWR1yRZLlpQN98w==", + "license": "MIT", + "dependencies": { + "prosemirror-keymap": "^1.1.2", + "prosemirror-model": "^1.8.1", + "prosemirror-state": "^1.3.1", + "prosemirror-transform": "^1.2.1", + "prosemirror-view": "^1.13.3" + } + }, + "node_modules/prosemirror-trailing-node": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/prosemirror-trailing-node/-/prosemirror-trailing-node-3.0.0.tgz", + "integrity": "sha512-xiun5/3q0w5eRnGYfNlW1uU9W6x5MoFKWwq/0TIRgt09lv7Hcser2QYV8t4muXbEr+Fwo0geYn79Xs4GKywrRQ==", + "license": "MIT", + "dependencies": { + "@remirror/core-constants": "3.0.0", + "escape-string-regexp": "^4.0.0" + }, + "peerDependencies": { + "prosemirror-model": "^1.22.1", + "prosemirror-state": "^1.4.2", + "prosemirror-view": "^1.33.8" + } + }, "node_modules/prosemirror-transform": { - "version": "1.10.0", - "resolved": "https://registry.npmjs.org/prosemirror-transform/-/prosemirror-transform-1.10.0.tgz", - "integrity": "sha512-9UOgFSgN6Gj2ekQH5CTDJ8Rp/fnKR2IkYfGdzzp5zQMFsS4zDllLVx/+jGcX86YlACpG7UR5fwAXiWzxqWtBTg==", + "version": "1.10.2", + "resolved": 
"https://registry.npmjs.org/prosemirror-transform/-/prosemirror-transform-1.10.2.tgz", + "integrity": "sha512-2iUq0wv2iRoJO/zj5mv8uDUriOHWzXRnOTVgCzSXnktS/2iQRa3UUQwVlkBlYZFtygw6Nh1+X4mGqoYBINn5KQ==", + "license": "MIT", "dependencies": { "prosemirror-model": "^1.21.0" } }, "node_modules/prosemirror-view": { - "version": "1.34.3", - "resolved": "https://registry.npmjs.org/prosemirror-view/-/prosemirror-view-1.34.3.tgz", - "integrity": "sha512-mKZ54PrX19sSaQye+sef+YjBbNu2voNwLS1ivb6aD2IRmxRGW64HU9B644+7OfJStGLyxvOreKqEgfvXa91WIA==", + "version": "1.36.0", + "resolved": "https://registry.npmjs.org/prosemirror-view/-/prosemirror-view-1.36.0.tgz", + "integrity": "sha512-U0GQd5yFvV5qUtT41X1zCQfbw14vkbbKwLlQXhdylEmgpYVHkefXYcC4HHwWOfZa3x6Y8wxDLUBv7dxN5XQ3nA==", + "license": "MIT", "dependencies": { "prosemirror-model": "^1.20.0", "prosemirror-state": "^1.0.0", @@ -9229,7 +9769,7 @@ "version": "7.8.1", "resolved": "https://registry.npmjs.org/rxjs/-/rxjs-7.8.1.tgz", "integrity": "sha512-AA3TVj+0A2iuIoQkWEK/tqFjBq2j+6PO6Y0zJcvzLAFhEFIO3HL0vls9hWLncZbAAbK0mar7oZ4V079I/qPMxg==", - "dev": true, + "devOptional": true, "dependencies": { "tslib": "^2.1.0" } @@ -9322,6 +9862,387 @@ "rimraf": "bin.js" } }, + "node_modules/sass-embedded": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded/-/sass-embedded-1.81.0.tgz", + "integrity": "sha512-uZQ2Faxb1oWBHpeSSzjxnhClbMb3QadN0ql0ZFNuqWOLUxwaVhrMlMhPq6TDPbbfDUjihuwrMCuy695Bgna5RA==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "@bufbuild/protobuf": "^2.0.0", + "buffer-builder": "^0.2.0", + "colorjs.io": "^0.5.0", + "immutable": "^5.0.2", + "rxjs": "^7.4.0", + "supports-color": "^8.1.1", + "sync-child-process": "^1.0.2", + "varint": "^6.0.0" + }, + "bin": { + "sass": "dist/bin/sass.js" + }, + "engines": { + "node": ">=16.0.0" + }, + "optionalDependencies": { + "sass-embedded-android-arm": "1.81.0", + "sass-embedded-android-arm64": "1.81.0", + "sass-embedded-android-ia32": "1.81.0", + "sass-embedded-android-riscv64": "1.81.0", + "sass-embedded-android-x64": "1.81.0", + "sass-embedded-darwin-arm64": "1.81.0", + "sass-embedded-darwin-x64": "1.81.0", + "sass-embedded-linux-arm": "1.81.0", + "sass-embedded-linux-arm64": "1.81.0", + "sass-embedded-linux-ia32": "1.81.0", + "sass-embedded-linux-musl-arm": "1.81.0", + "sass-embedded-linux-musl-arm64": "1.81.0", + "sass-embedded-linux-musl-ia32": "1.81.0", + "sass-embedded-linux-musl-riscv64": "1.81.0", + "sass-embedded-linux-musl-x64": "1.81.0", + "sass-embedded-linux-riscv64": "1.81.0", + "sass-embedded-linux-x64": "1.81.0", + "sass-embedded-win32-arm64": "1.81.0", + "sass-embedded-win32-ia32": "1.81.0", + "sass-embedded-win32-x64": "1.81.0" + } + }, + "node_modules/sass-embedded-android-arm": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-android-arm/-/sass-embedded-android-arm-1.81.0.tgz", + "integrity": "sha512-NWEmIuaIEsGFNsIRa+5JpIpPJyZ32H15E85CNZqEIhhwWlk9UNw7vlOCmTH8MtabtnACwC/2NG8VyNa3nxKzUQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-android-arm64": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-android-arm64/-/sass-embedded-android-arm64-1.81.0.tgz", + "integrity": "sha512-I36P77/PKAHx6sqOmexO2iEY5kpsmQ1VxcgITZSOxPMQhdB6m4t3bTabfDuWQQmCrqqiNFtLQHeytB65bUqwiw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + 
"node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-android-ia32": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-android-ia32/-/sass-embedded-android-ia32-1.81.0.tgz", + "integrity": "sha512-k8V1usXw30w1GVxvrteG1RzgYJzYQ9PfL2aeOqGdroBN7zYTD9VGJXTGcxA4IeeRxmRd7szVW2mKXXS472fh8g==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-android-riscv64": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-android-riscv64/-/sass-embedded-android-riscv64-1.81.0.tgz", + "integrity": "sha512-RXlanyLXEpN/DEehXgLuKPsqT//GYlsGFxKXgRiCc8hIPAueFLQXKJmLWlL3BEtHgmFdbsStIu4aZCcb1hOFlQ==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-android-x64": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-android-x64/-/sass-embedded-android-x64-1.81.0.tgz", + "integrity": "sha512-RQG0FxGQ1DERNyUDED8+BDVaLIjI+BNg8lVcyqlLZUrWY6NhzjwYEeiN/DNZmMmHtqDucAPNDcsdVUNQqsBy2A==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-darwin-arm64": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-darwin-arm64/-/sass-embedded-darwin-arm64-1.81.0.tgz", + "integrity": "sha512-gLKbsfII9Ppua76N41ODFnKGutla9qv0OGAas8gxe0jYBeAQFi/1iKQYdNtQtKi4mA9n5TQTqz+HHCKszZCoyA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-darwin-x64": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-darwin-x64/-/sass-embedded-darwin-x64-1.81.0.tgz", + "integrity": "sha512-7uMOlT9hD2KUJCbTN2XcfghDxt/rc50ujjfSjSHjX1SYj7mGplkINUXvVbbvvaV2wt6t9vkGkCo5qNbeBhfwBg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-arm": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-arm/-/sass-embedded-linux-arm-1.81.0.tgz", + "integrity": "sha512-REqR9qM4RchCE3cKqzRy9Q4zigIV82SbSpCi/O4O3oK3pg2I1z7vkb3TiJsivusG/li7aqKZGmYOtAXjruGQDA==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-arm64": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-arm64/-/sass-embedded-linux-arm64-1.81.0.tgz", + "integrity": "sha512-jy4bvhdUmqbyw1jv1f3Uxl+MF8EU/Y/GDx4w6XPJm4Ds+mwH/TwnyAwsxxoBhWfnBnW8q2ADy039DlS5p+9csQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-ia32": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-ia32/-/sass-embedded-linux-ia32-1.81.0.tgz", + "integrity": "sha512-ga/Jk4q5Bn1aC+iHJteDZuLSKnmBUiS3dEg1fnl/Z7GaHIChceKDJOw0zNaILRXI0qT2E1at9MwzoRaRA5Nn/g==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-arm": { + "version": "1.81.0", + "resolved": 
"https://registry.npmjs.org/sass-embedded-linux-musl-arm/-/sass-embedded-linux-musl-arm-1.81.0.tgz", + "integrity": "sha512-oWVUvQ4d5Kx1Md75YXZl5z1WBjc+uOhfRRqzkJ3nWc8tjszxJN+y/5EOJavhsNI3/2yoTt6eMXRTqDD9b0tWSQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-arm64": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-arm64/-/sass-embedded-linux-musl-arm64-1.81.0.tgz", + "integrity": "sha512-hpntWf5kjkoxncA1Vh8vhsUOquZ8AROZKx0rQh7ZjSRs4JrYZASz1cfevPKaEM3wIim/nYa6TJqm0VqWsrERlA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-ia32": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-ia32/-/sass-embedded-linux-musl-ia32-1.81.0.tgz", + "integrity": "sha512-UEXUYkBuqTSwg5JNWiNlfMZ1Jx6SJkaEdx+fsL3Tk099L8cKSoJWH2EPz4ZJjNbyIMymrSdVfymheTeZ8u24xA==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-riscv64": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-riscv64/-/sass-embedded-linux-musl-riscv64-1.81.0.tgz", + "integrity": "sha512-1D7OznytbIhx2XDHWi1nuQ8d/uCVR7FGGzELgaU//T8A9DapVTUgPKvB70AF1k4GzChR9IXU/WvFZs2hDTbaJg==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-x64": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-x64/-/sass-embedded-linux-musl-x64-1.81.0.tgz", + "integrity": "sha512-ia6VCTeVDQtBSMktXRFza1AZCt8/6aUoujot6Ugf4KmdytQqPJIHxkHaGftm5xwi9WdrMGYS7zgolToPijR11A==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-riscv64": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-riscv64/-/sass-embedded-linux-riscv64-1.81.0.tgz", + "integrity": "sha512-KbxSsqu4tT1XbhZfJV/5NfW0VtJIGlD58RjqJqJBi8Rnjrx29/upBsuwoDWtsPV/LhoGwwU1XkSa9Q1ifCz4fQ==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-x64": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-x64/-/sass-embedded-linux-x64-1.81.0.tgz", + "integrity": "sha512-AMDeVY2T9WAnSFkuQcsOn5c29GRs/TuqnCiblKeXfxCSKym5uKdBl/N7GnTV6OjzoxiJBbkYKdVIaS5By7Gj4g==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-win32-arm64": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-win32-arm64/-/sass-embedded-win32-arm64-1.81.0.tgz", + "integrity": "sha512-YOmBRYnygwWUmCoH14QbMRHjcvCJufeJBAp0m61tOJXIQh64ziwV4mjdqjS/Rx3zhTT4T+nulDUw4d3kLiMncA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-win32-ia32": { + "version": "1.81.0", + "resolved": 
"https://registry.npmjs.org/sass-embedded-win32-ia32/-/sass-embedded-win32-ia32-1.81.0.tgz", + "integrity": "sha512-HFfr/C+uLJGGTENdnssuNTmXI/xnIasUuEHEKqI+2J0FHCWT5cpz3PGAOHymPyJcZVYGUG/7gIxIx/d7t0LFYw==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-win32-x64": { + "version": "1.81.0", + "resolved": "https://registry.npmjs.org/sass-embedded-win32-x64/-/sass-embedded-win32-x64-1.81.0.tgz", + "integrity": "sha512-wxj52jDcIAwWcXb7ShZ7vQYKcVUkJ+04YM9l46jDY+qwHzliGuorAUyujLyKTE9heGD3gShJ3wPPC1lXzq6v9A==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded/node_modules/supports-color": { + "version": "8.1.1", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", + "integrity": "sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/supports-color?sponsor=1" + } + }, "node_modules/semver": { "version": "7.6.3", "resolved": "https://registry.npmjs.org/semver/-/semver-7.6.3.tgz", @@ -9456,16 +10377,17 @@ } }, "node_modules/sirv": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/sirv/-/sirv-2.0.4.tgz", - "integrity": "sha512-94Bdh3cC2PKrbgSOUqTiGPWVZeSiXfKOVZNJniWoqrWrRkB1CJzBU3NEbiTsPcYy1lDsANA/THzS+9WBiy5nfQ==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/sirv/-/sirv-3.0.0.tgz", + "integrity": "sha512-BPwJGUeDaDCHihkORDchNyyTvWFhcusy1XMmhEVTQTwGeybFbp8YEmB+njbPnth1FibULBSBVwCQni25XlCUDg==", + "license": "MIT", "dependencies": { "@polka/url": "^1.0.0-next.24", "mrmime": "^2.0.0", "totalist": "^3.0.0" }, "engines": { - "node": ">= 10" + "node": ">=18" } }, "node_modules/slash": { @@ -10037,6 +10959,29 @@ "integrity": "sha512-0K91MEXFpBUaywiwSSkmKjnGcasG/rVBXFLJz5DrgGabpYD6N+3yZrfD6uUIfpuTu65DZLHi7N8CizHc07BPZA==", "dev": true }, + "node_modules/sync-child-process": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/sync-child-process/-/sync-child-process-1.0.2.tgz", + "integrity": "sha512-8lD+t2KrrScJ/7KXCSyfhT3/hRq78rC0wBFqNJXv3mZyn6hW2ypM05JmlSvtqRbeq6jqA94oHbxAr2vYsJ8vDA==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "sync-message-port": "^1.0.0" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/sync-message-port": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/sync-message-port/-/sync-message-port-1.1.3.tgz", + "integrity": "sha512-GTt8rSKje5FilG+wEdfCkOcLL7LWqpMlr2c3LRuKt/YXxcJ52aGSbGBAdI4L3aaqfrBt6y711El53ItyH1NWzg==", + "devOptional": true, + "license": "MIT", + "engines": { + "node": ">=16.0.0" + } + }, "node_modules/tabbable": { "version": "6.2.0", "resolved": "https://registry.npmjs.org/tabbable/-/tabbable-6.2.0.tgz", @@ -10334,6 +11279,7 @@ "version": "3.0.1", "resolved": "https://registry.npmjs.org/totalist/-/totalist-3.0.1.tgz", "integrity": "sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==", + "license": "MIT", "engines": { "node": ">=6" } @@ -10409,6 +11355,7 @@ "version": "7.2.0", "resolved": "https://registry.npmjs.org/turndown/-/turndown-7.2.0.tgz", "integrity": 
"sha512-eCZGBN4nNNqM9Owkv9HAtWRYfLA4h909E/WGAWWBpmB275ehNhZyk87/Tpvjbp0jjNl9XwCsbe6bm6CqFsgD+A==", + "license": "MIT", "dependencies": { "@mixmark-io/domino": "^2.2.0" } @@ -10622,6 +11569,13 @@ "node": ">= 10.13.0" } }, + "node_modules/varint": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/varint/-/varint-6.0.0.tgz", + "integrity": "sha512-cXEIW6cfr15lFv563k4GuVuW/fiwjknytD37jIOLSdSWuOI6WnO/oKwmP2FQTU2l01LP8/M5TSAJpzUaGe3uWg==", + "devOptional": true, + "license": "MIT" + }, "node_modules/verror": { "version": "1.10.0", "resolved": "https://registry.npmjs.org/verror/-/verror-1.10.0.tgz", diff --git a/package.json b/package.json index 319402307..3b5911791 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.3.35", + "version": "0.4.8", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", @@ -37,6 +37,7 @@ "postcss": "^8.4.31", "prettier": "^3.3.3", "prettier-plugin-svelte": "^3.2.6", + "sass-embedded": "^1.81.0", "svelte": "^4.2.18", "svelte-check": "^3.8.5", "svelte-confetti": "^1.3.2", @@ -49,12 +50,21 @@ "type": "module", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", + "codemirror-lang-hcl": "^0.0.0-beta.2", "@codemirror/lang-python": "^6.1.6", "@codemirror/language-data": "^6.5.1", "@codemirror/theme-one-dark": "^6.1.2", "@huggingface/transformers": "^3.0.0", + "@mediapipe/tasks-vision": "^0.10.17", "@pyscript/core": "^0.4.32", "@sveltejs/adapter-node": "^2.0.0", + "@tiptap/core": "^2.10.0", + "@tiptap/extension-code-block-lowlight": "^2.10.0", + "@tiptap/extension-highlight": "^2.10.0", + "@tiptap/extension-placeholder": "^2.10.0", + "@tiptap/extension-typography": "^2.10.0", + "@tiptap/pm": "^2.10.0", + "@tiptap/starter-kit": "^2.10.0", "@xyflow/svelte": "^0.1.19", "async": "^3.2.5", "bits-ui": "^0.19.7", diff --git a/pyproject.toml b/pyproject.toml index b6248063d..de14a9fa1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "fastapi==0.111.0", "uvicorn[standard]==0.30.6", "pydantic==2.9.2", - "python-multipart==0.0.9", + "python-multipart==0.0.18", "Flask==3.0.3", "Flask-Cors==5.0.0", @@ -19,20 +19,23 @@ dependencies = [ "passlib[bcrypt]==1.7.4", "requests==2.32.3", - "aiohttp==3.10.8", + "aiohttp==3.11.8", "async-timeout", + "aiocache", + "aiofiles", "sqlalchemy==2.0.32", - "alembic==1.13.2", + "alembic==1.14.0", "peewee==3.17.6", "peewee-migrate==1.12.2", "psycopg2-binary==2.9.9", + "pgvector==0.3.5", "PyMySQL==1.1.1", "bcrypt==4.2.0", "pymongo", "redis", - "boto3==1.35.0", + "boto3==1.35.53", "argon2-cffi==23.1.0", "APScheduler==3.10.4", @@ -42,21 +45,23 @@ dependencies = [ "google-generativeai==0.7.2", "tiktoken", - "langchain==0.2.15", - "langchain-community==0.2.12", + "langchain==0.3.7", + "langchain-community==0.3.7", "langchain-chroma==0.1.4", "fake-useragent==1.5.1", - "chromadb==0.5.9", - "pymilvus==2.4.7", + "chromadb==0.5.15", + "pymilvus==2.5.0", + "qdrant-client~=1.12.0", + "opensearch-py==2.7.1", - "sentence-transformers==3.2.0", + "sentence-transformers==3.3.1", "colbert-ai==0.2.21", "einops==0.8.0", "ftfy==6.2.3", "pypdf==4.3.1", - "xhtml2pdf==0.2.16", + "fpdf2==2.7.9", "pymdown-extensions==10.11.2", "docx2txt==0.8", "python-pptx==1.0.0", @@ -70,11 +75,11 @@ dependencies = [ "xlrd==2.0.1", "validators==0.33.0", "psutil", + "sentencepiece", + "soundfile==0.12.1", "opencv-python-headless==4.10.0.84", "rapidocr-onnxruntime==1.3.24", - - "fpdf2==2.7.9", "rank-bm25==0.2.2", "faster-whisper==1.0.3", @@ -84,27 +89,30 @@ dependencies 
= [ "black==24.8.0", "langfuse==2.44.0", - "youtube-transcript-api==0.6.2", + "youtube-transcript-api==0.6.3", "pytube==15.0.0", "extract_msg", "pydub", - "duckduckgo-search~=6.2.13", + "duckduckgo-search~=6.3.5", "docker~=7.1.0", "pytest~=8.3.2", "pytest-docker~=3.1.1", - "googleapis-common-protos==1.63.2" + "googleapis-common-protos==1.63.2", + + "ldap3==2.9.1" ] readme = "README.md" -requires-python = ">= 3.11, < 3.12.0a1" +requires-python = ">= 3.11, < 3.13.0a1" dynamic = ["version"] classifiers = [ "Development Status :: 4 - Beta", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Communications :: Chat", "Topic :: Multimedia", ] diff --git a/src/app.css b/src/app.css index ca5249bbd..cf0afea4f 100644 --- a/src/app.css +++ b/src/app.css @@ -16,6 +16,12 @@ font-display: swap; } +@font-face { + font-family: 'InstrumentSerif'; + src: url('/assets/fonts/InstrumentSerif-Regular.ttf'); + font-display: swap; +} + html { word-break: break-word; } @@ -26,6 +32,10 @@ code { width: auto; } +.font-secondary { + font-family: 'InstrumentSerif', sans-serif; +} + math { margin-top: 1rem; } @@ -35,15 +45,15 @@ math { } .input-prose { - @apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; + @apply prose dark:prose-invert prose-headings:font-semibold prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; } .input-prose-sm { - @apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line text-sm; + @apply prose dark:prose-invert prose-headings:font-semibold prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line text-sm; } .markdown-prose { - @apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; + @apply prose dark:prose-invert prose-headings:font-semibold prose-hr:my-4 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; } .markdown a { @@ -189,19 +199,103 @@ input[type='number'] { } .ProseMirror { - @apply h-full min-h-fit max-h-full whitespace-pre-wrap; + @apply h-full min-h-fit max-h-full whitespace-pre-wrap; } .ProseMirror:focus { outline: none; } -.placeholder::after { +.ProseMirror p.is-editor-empty:first-child::before { content: attr(data-placeholder); - cursor: text; + float: left; + color: #adb5bd; pointer-events: none; - float: left; - - @apply absolute inset-0 z-0 text-gray-500; + @apply line-clamp-1 absolute; +} + +.ai-autocompletion::after { + color: #a0a0a0; + + content: attr(data-suggestion); + pointer-events: none; +} + +.tiptap > pre > code { + border-radius: 0.4rem; + font-size: 0.85rem; + padding: 0.25em 0.3em; + + @apply dark:bg-gray-800 bg-gray-100; +} + +.tiptap > pre { 
+ border-radius: 0.5rem; + font-family: 'JetBrainsMono', monospace; + margin: 1.5rem 0; + padding: 0.75rem 1rem; + + @apply dark:bg-gray-800 bg-gray-100; +} + +.tiptap p code { + color: #eb5757; + border-width: 0px; + padding: 3px 8px; + font-size: 0.8em; + font-weight: 600; + @apply rounded-md dark:bg-gray-800 bg-gray-100 mx-0.5; +} + +/* Code styling */ +.hljs-comment, +.hljs-quote { + color: #616161; +} + +.hljs-variable, +.hljs-template-variable, +.hljs-attribute, +.hljs-tag, +.hljs-regexp, +.hljs-link, +.hljs-name, +.hljs-selector-id, +.hljs-selector-class { + color: #f98181; +} + +.hljs-number, +.hljs-meta, +.hljs-built_in, +.hljs-builtin-name, +.hljs-literal, +.hljs-type, +.hljs-params { + color: #fbbc88; +} + +.hljs-string, +.hljs-symbol, +.hljs-bullet { + color: #b9f18d; +} + +.hljs-title, +.hljs-section { + color: #faf594; +} + +.hljs-keyword, +.hljs-selector-tag { + color: #70cff8; +} + +.hljs-emphasis { + font-style: italic; +} + +.hljs-strong { + font-weight: 700; } diff --git a/src/app.html b/src/app.html index f6e46c9cf..537e28dbe 100644 --- a/src/app.html +++ b/src/app.html @@ -2,9 +2,12 @@ - - - + + + + + + { return res; }; +export const ldapUserSignIn = async (user: string, password: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/ldap`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + credentials: 'include', + body: JSON.stringify({ + user: user, + password: password + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getLdapConfig = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/admin/config/ldap`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateLdapConfig = async (token: string = '', enable_ldap: boolean) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/admin/config/ldap`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + enable_ldap: enable_ldap + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getLdapServer = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/admin/config/ldap/server`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateLdapServer = async (token: string = '', body: object) => { + let error = null; + + const res = await 
fetch(`${WEBUI_API_BASE_URL}/auths/admin/config/ldap/server`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify(body) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const userSignIn = async (email: string, password: string) => { let error = null; diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts index 0c4de6ad6..e9faf346b 100644 --- a/src/lib/apis/configs/index.ts +++ b/src/lib/apis/configs/index.ts @@ -58,17 +58,44 @@ export const exportConfig = async (token: string) => { return res; }; -export const setDefaultModels = async (token: string, models: string) => { +export const getModelsConfig = async (token: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/configs/default/models`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/models`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const setModelsConfig = async (token: string, config: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/models`, { method: 'POST', headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ - models: models + ...config }) }) .then(async (res) => { @@ -91,7 +118,7 @@ export const setDefaultModels = async (token: string, models: string) => { export const setDefaultPromptSuggestions = async (token: string, promptSuggestions: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/configs/default/suggestions`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/suggestions`, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/src/lib/apis/groups/index.ts b/src/lib/apis/groups/index.ts new file mode 100644 index 000000000..b7d4f8ef9 --- /dev/null +++ b/src/lib/apis/groups/index.ts @@ -0,0 +1,162 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const createNewGroup = async (token: string, group: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/create`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...group + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getGroups = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if 
(error) { + throw error; + } + + return res; +}; + +export const getGroupById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateGroupById = async (token: string, id: string, group: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...group + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteGroupById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}/delete`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 40d0e0392..d06fbf3d7 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -1,9 +1,8 @@ import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; -export const getModels = async (token: string = '') => { +export const getModels = async (token: string = '', base: boolean = false) => { let error = null; - - const res = await fetch(`${WEBUI_BASE_URL}/api/models`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/models${base ? '/base' : ''}`, { method: 'GET', headers: { Accept: 'application/json', @@ -16,8 +15,8 @@ export const getModels = async (token: string = '') => { return res.json(); }) .catch((err) => { - console.log(err); error = err; + console.log(err); return null; }); @@ -26,42 +25,6 @@ export const getModels = async (token: string = '') => { } let models = res?.data ?? 
[]; - - models = models - .filter((models) => models) - // Sort the models - .sort((a, b) => { - // Check if models have position property - const aHasPosition = a.info?.meta?.position !== undefined; - const bHasPosition = b.info?.meta?.position !== undefined; - - // If both a and b have the position property - if (aHasPosition && bHasPosition) { - return a.info.meta.position - b.info.meta.position; - } - - // If only a has the position property, it should come first - if (aHasPosition) return -1; - - // If only b has the position property, it should come first - if (bHasPosition) return 1; - - // Compare case-insensitively by name for models without position property - const lowerA = a.name.toLowerCase(); - const lowerB = b.name.toLowerCase(); - - if (lowerA < lowerB) return -1; - if (lowerA > lowerB) return 1; - - // If same case-insensitively, sort by original strings, - // lowercase will come before uppercase due to ASCII values - if (a.name < b.name) return -1; - if (a.name > b.name) return 1; - - return 0; // They are equal - }); - - console.log(models); return models; }; @@ -147,7 +110,7 @@ export const chatAction = async (token: string, action_id: string, body: ChatAct export const getTaskConfig = async (token: string = '') => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/config`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/config`, { method: 'GET', headers: { Accept: 'application/json', @@ -175,7 +138,7 @@ export const getTaskConfig = async (token: string = '') => { export const updateTaskConfig = async (token: string, config: object) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/config/update`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/config/update`, { method: 'POST', headers: { Accept: 'application/json', @@ -213,7 +176,7 @@ export const generateTitle = async ( ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/title/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/title/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -253,7 +216,7 @@ export const generateTags = async ( ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/tags/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/tags/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -325,7 +288,7 @@ export const generateEmoji = async ( ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/emoji/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/emoji/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -365,15 +328,16 @@ export const generateEmoji = async ( return null; }; -export const generateSearchQuery = async ( +export const generateQueries = async ( token: string = '', model: string, messages: object[], - prompt: string + prompt: string, + type?: string = 'web_search' ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/queries/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -383,7 +347,8 @@ export const generateSearchQuery = async ( body: JSON.stringify({ model: model, messages: messages, - prompt: prompt + prompt: prompt, + type: type }) }) .then(async (res) => { @@ -402,7 +367,105 @@ export const generateSearchQuery = async ( throw error; } - return 
res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt; + // Step 1: Safely extract the response string + const response = res?.choices[0]?.message?.content ?? ''; + + try { + const jsonStartIndex = response.indexOf('{'); + const jsonEndIndex = response.lastIndexOf('}'); + + if (jsonStartIndex !== -1 && jsonEndIndex !== -1) { + const jsonResponse = response.substring(jsonStartIndex, jsonEndIndex + 1); + + // Step 5: Parse the JSON block + const parsed = JSON.parse(jsonResponse); + + // Step 6: If there's a "queries" key, return the queries array; otherwise, return an empty array + if (parsed && parsed.queries) { + return Array.isArray(parsed.queries) ? parsed.queries : []; + } else { + return []; + } + } + + // If no valid JSON block found, return response as is + return [response]; + } catch (e) { + // Catch and safely return empty array on any parsing errors + console.error('Failed to parse response: ', e); + return [response]; + } +}; + +export const generateAutoCompletion = async ( + token: string = '', + model: string, + prompt: string, + messages?: object[], + type: string = 'search query' +) => { + const controller = new AbortController(); + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/auto/completions`, { + signal: controller.signal, + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + prompt: prompt, + ...(messages && { messages: messages }), + type: type, + stream: false + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } + return null; + }); + + if (error) { + throw error; + } + + const response = res?.choices[0]?.message?.content ?? 
''; + + try { + const jsonStartIndex = response.indexOf('{'); + const jsonEndIndex = response.lastIndexOf('}'); + + if (jsonStartIndex !== -1 && jsonEndIndex !== -1) { + const jsonResponse = response.substring(jsonStartIndex, jsonEndIndex + 1); + + // Step 5: Parse the JSON block + const parsed = JSON.parse(jsonResponse); + + // Step 6: If there's a "queries" key, return the queries array; otherwise, return an empty array + if (parsed && parsed.text) { + return parsed.text; + } else { + return ''; + } + } + + // If no valid JSON block found, return response as is + return response; + } catch (e) { + // Catch and safely return empty array on any parsing errors + console.error('Failed to parse response: ', e); + return response; + } }; export const generateMoACompletion = async ( @@ -414,7 +477,7 @@ export const generateMoACompletion = async ( const controller = new AbortController(); let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/moa/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/moa/completions`, { signal: controller.signal, method: 'POST', headers: { @@ -444,7 +507,7 @@ export const generateMoACompletion = async ( export const getPipelinesList = async (token: string = '') => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/list`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/list`, { method: 'GET', headers: { Accept: 'application/json', @@ -478,7 +541,7 @@ export const uploadPipeline = async (token: string, file: File, urlIdx: string) formData.append('file', file); formData.append('urlIdx', urlIdx); - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/upload`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/upload`, { method: 'POST', headers: { ...(token && { authorization: `Bearer ${token}` }) @@ -510,7 +573,7 @@ export const uploadPipeline = async (token: string, file: File, urlIdx: string) export const downloadPipeline = async (token: string, url: string, urlIdx: string) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/add`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/add`, { method: 'POST', headers: { Accept: 'application/json', @@ -546,7 +609,7 @@ export const downloadPipeline = async (token: string, url: string, urlIdx: strin export const deletePipeline = async (token: string, id: string, urlIdx: string) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/delete`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/delete`, { method: 'DELETE', headers: { Accept: 'application/json', @@ -587,7 +650,7 @@ export const getPipelines = async (token: string, urlIdx?: string) => { searchParams.append('urlIdx', urlIdx); } - const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines?${searchParams.toString()}`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines?${searchParams.toString()}`, { method: 'GET', headers: { Accept: 'application/json', @@ -622,7 +685,7 @@ export const getPipelineValves = async (token: string, pipeline_id: string, urlI } const res = await fetch( - `${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves?${searchParams.toString()}`, + `${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves?${searchParams.toString()}`, { method: 'GET', headers: { @@ -658,7 +721,7 @@ export const getPipelineValvesSpec = async (token: string, pipeline_id: string, } const res = await fetch( - 
`${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves/spec?${searchParams.toString()}`, + `${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves/spec?${searchParams.toString()}`, { method: 'GET', headers: { @@ -699,7 +762,7 @@ export const updatePipelineValves = async ( } const res = await fetch( - `${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves/update?${searchParams.toString()}`, + `${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves/update?${searchParams.toString()}`, { method: 'POST', headers: { diff --git a/src/lib/apis/knowledge/index.ts b/src/lib/apis/knowledge/index.ts index 842866899..c5fad1323 100644 --- a/src/lib/apis/knowledge/index.ts +++ b/src/lib/apis/knowledge/index.ts @@ -1,6 +1,11 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const createNewKnowledge = async (token: string, name: string, description: string) => { +export const createNewKnowledge = async ( + token: string, + name: string, + description: string, + accessControl: null | object +) => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/create`, { @@ -12,7 +17,8 @@ export const createNewKnowledge = async (token: string, name: string, descriptio }, body: JSON.stringify({ name: name, - description: description + description: description, + access_control: accessControl }) }) .then(async (res) => { @@ -32,7 +38,7 @@ export const createNewKnowledge = async (token: string, name: string, descriptio return res; }; -export const getKnowledgeItems = async (token: string = '') => { +export const getKnowledgeBases = async (token: string = '') => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/`, { @@ -63,6 +69,37 @@ export const getKnowledgeItems = async (token: string = '') => { return res; }; +export const getKnowledgeBaseList = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/list`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getKnowledgeById = async (token: string, id: string) => { let error = null; @@ -99,6 +136,7 @@ type KnowledgeUpdateForm = { name?: string; description?: string; data?: object; + access_control?: null | object; }; export const updateKnowledgeById = async (token: string, id: string, form: KnowledgeUpdateForm) => { @@ -114,7 +152,8 @@ export const updateKnowledgeById = async (token: string, id: string, form: Knowl body: JSON.stringify({ name: form?.name ? form.name : undefined, description: form?.description ? form.description : undefined, - data: form?.data ? form.data : undefined + data: form?.data ? 
form.data : undefined, + access_control: form.access_control }) }) .then(async (res) => { diff --git a/src/lib/apis/models/index.ts b/src/lib/apis/models/index.ts index 9faa358d3..5880874bb 100644 --- a/src/lib/apis/models/index.ts +++ b/src/lib/apis/models/index.ts @@ -1,38 +1,9 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const addNewModel = async (token: string, model: object) => { +export const getModels = async (token: string = '') => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/models/add`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify(model) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const getModelInfos = async (token: string = '') => { - let error = null; - - const res = await fetch(`${WEBUI_API_BASE_URL}/models`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/`, { method: 'GET', headers: { Accept: 'application/json', @@ -60,13 +31,73 @@ export const getModelInfos = async (token: string = '') => { return res; }; +export const getBaseModels = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/base`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const createNewModel = async (token: string, model: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/create`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify(model) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getModelById = async (token: string, id: string) => { let error = null; const searchParams = new URLSearchParams(); searchParams.append('id', id); - const res = await fetch(`${WEBUI_API_BASE_URL}/models?${searchParams.toString()}`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/model?${searchParams.toString()}`, { method: 'GET', headers: { Accept: 'application/json', @@ -95,13 +126,48 @@ export const getModelById = async (token: string, id: string) => { return res; }; +export const toggleModelById = async (token: string, id: string) => { + let error = null; + + const searchParams = new URLSearchParams(); + searchParams.append('id', id); + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/model/toggle?${searchParams.toString()}`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + 
if (error) { + throw error; + } + + return res; +}; + export const updateModelById = async (token: string, id: string, model: object) => { let error = null; const searchParams = new URLSearchParams(); searchParams.append('id', id); - const res = await fetch(`${WEBUI_API_BASE_URL}/models/update?${searchParams.toString()}`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/model/update?${searchParams.toString()}`, { method: 'POST', headers: { Accept: 'application/json', @@ -137,7 +203,39 @@ export const deleteModelById = async (token: string, id: string) => { const searchParams = new URLSearchParams(); searchParams.append('id', id); - const res = await fetch(`${WEBUI_API_BASE_URL}/models/delete?${searchParams.toString()}`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/model/delete?${searchParams.toString()}`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteAllModels = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/delete/all`, { method: 'DELETE', headers: { Accept: 'application/json', diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index d4e994312..16eed9f21 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -1,5 +1,40 @@ import { OLLAMA_API_BASE_URL } from '$lib/constants'; +export const verifyOllamaConnection = async ( + token: string = '', + url: string = '', + key: string = '' +) => { + let error = null; + + const res = await fetch(`${OLLAMA_API_BASE_URL}/verify`, { + method: 'POST', + headers: { + Accept: 'application/json', + Authorization: `Bearer ${token}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + url, + key + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = `Ollama: ${err?.error?.message ?? 'Network Problem'}`; + return []; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getOllamaConfig = async (token: string = '') => { let error = null; @@ -32,7 +67,13 @@ export const getOllamaConfig = async (token: string = '') => { return res; }; -export const updateOllamaConfig = async (token: string = '', enable_ollama_api: boolean) => { +type OllamaConfig = { + ENABLE_OLLAMA_API: boolean; + OLLAMA_BASE_URLS: string[]; + OLLAMA_API_CONFIGS: object; +}; + +export const updateOllamaConfig = async (token: string = '', config: OllamaConfig) => { let error = null; const res = await fetch(`${OLLAMA_API_BASE_URL}/config/update`, { @@ -43,7 +84,7 @@ export const updateOllamaConfig = async (token: string = '', enable_ollama_api: ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - enable_ollama_api: enable_ollama_api + ...config }) }) .then(async (res) => { @@ -166,10 +207,10 @@ export const getOllamaVersion = async (token: string, urlIdx?: number) => { return res?.version ?? 
false; }; -export const getOllamaModels = async (token: string = '') => { +export const getOllamaModels = async (token: string = '', urlIdx: null | number = null) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/api/tags`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/tags${urlIdx !== null ? `/${urlIdx}` : ''}`, { method: 'GET', headers: { Accept: 'application/json', diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 2bb11d12a..1988dc0c3 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -32,7 +32,14 @@ export const getOpenAIConfig = async (token: string = '') => { return res; }; -export const updateOpenAIConfig = async (token: string = '', enable_openai_api: boolean) => { +type OpenAIConfig = { + ENABLE_OPENAI_API: boolean; + OPENAI_API_BASE_URLS: string[]; + OPENAI_API_KEYS: string[]; + OPENAI_API_CONFIGS: object; +}; + +export const updateOpenAIConfig = async (token: string = '', config: OpenAIConfig) => { let error = null; const res = await fetch(`${OPENAI_API_BASE_URL}/config/update`, { @@ -43,7 +50,7 @@ export const updateOpenAIConfig = async (token: string = '', enable_openai_api: ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - enable_openai_api: enable_openai_api + ...config }) }) .then(async (res) => { @@ -231,41 +238,39 @@ export const getOpenAIModels = async (token: string, urlIdx?: number) => { return res; }; -export const getOpenAIModelsDirect = async ( - base_url: string = 'https://api.openai.com/v1', - api_key: string = '' +export const verifyOpenAIConnection = async ( + token: string = '', + url: string = 'https://api.openai.com/v1', + key: string = '' ) => { let error = null; - const res = await fetch(`${base_url}/models`, { - method: 'GET', + const res = await fetch(`${OPENAI_API_BASE_URL}/verify`, { + method: 'POST', headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${api_key}` - } + Accept: 'application/json', + Authorization: `Bearer ${token}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + url, + key + }) }) .then(async (res) => { if (!res.ok) throw await res.json(); return res.json(); }) .catch((err) => { - console.log(err); error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`; - return null; + return []; }); if (error) { throw error; } - const models = Array.isArray(res) ? res : (res?.data ?? null); - - return models - .map((model) => ({ id: model.id, name: model.name ?? model.id, external: true })) - .filter((model) => (base_url.includes('openai') ? 
model.name.includes('gpt') : true)) - .sort((a, b) => { - return a.name.localeCompare(b.name); - }); + return res; }; export const generateOpenAIChatCompletion = async ( diff --git a/src/lib/apis/prompts/index.ts b/src/lib/apis/prompts/index.ts index ca9c7d543..0d796e5cd 100644 --- a/src/lib/apis/prompts/index.ts +++ b/src/lib/apis/prompts/index.ts @@ -1,11 +1,13 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const createNewPrompt = async ( - token: string, - command: string, - title: string, - content: string -) => { +type PromptItem = { + command: string; + title: string; + content: string; + access_control: null | object; +}; + +export const createNewPrompt = async (token: string, prompt: PromptItem) => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/create`, { @@ -16,9 +18,8 @@ export const createNewPrompt = async ( authorization: `Bearer ${token}` }, body: JSON.stringify({ - command: `/${command}`, - title: title, - content: content + ...prompt, + command: `/${prompt.command}` }) }) .then(async (res) => { @@ -69,6 +70,37 @@ export const getPrompts = async (token: string = '') => { return res; }; +export const getPromptList = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/list`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getPromptByCommand = async (token: string, command: string) => { let error = null; @@ -101,15 +133,10 @@ export const getPromptByCommand = async (token: string, command: string) => { return res; }; -export const updatePromptByCommand = async ( - token: string, - command: string, - title: string, - content: string -) => { +export const updatePromptByCommand = async (token: string, prompt: PromptItem) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/command/${command}/update`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/command/${prompt.command}/update`, { method: 'POST', headers: { Accept: 'application/json', @@ -117,9 +144,8 @@ export const updatePromptByCommand = async ( authorization: `Bearer ${token}` }, body: JSON.stringify({ - command: `/${command}`, - title: title, - content: content + ...prompt, + command: `/${prompt.command}` }) }) .then(async (res) => { diff --git a/src/lib/apis/retrieval/index.ts b/src/lib/apis/retrieval/index.ts index 6c6b18b9f..21ae792fa 100644 --- a/src/lib/apis/retrieval/index.ts +++ b/src/lib/apis/retrieval/index.ts @@ -40,6 +40,7 @@ type ContentExtractConfigForm = { type YoutubeConfigForm = { language: string[]; translation?: string | null; + proxy_url: string; }; type RAGConfigForm = { diff --git a/src/lib/apis/streaming/index.ts b/src/lib/apis/streaming/index.ts index a8249abe0..5617ce36c 100644 --- a/src/lib/apis/streaming/index.ts +++ b/src/lib/apis/streaming/index.ts @@ -5,7 +5,7 @@ type TextStreamUpdate = { done: boolean; value: string; // eslint-disable-next-line @typescript-eslint/no-explicit-any - citations?: any; + sources?: any; // eslint-disable-next-line @typescript-eslint/no-explicit-any selectedModelId?: any; error?: any; @@ -67,8 +67,8 @@ async function* openAIStreamToIterator( break; } - if 
(parsedData.citations) { - yield { done: false, value: '', citations: parsedData.citations }; + if (parsedData.sources) { + yield { done: false, value: '', sources: parsedData.sources }; continue; } @@ -77,10 +77,14 @@ async function* openAIStreamToIterator( continue; } + if (parsedData.usage) { + yield { done: false, value: '', usage: parsedData.usage }; + continue; + } + yield { done: false, value: parsedData.choices?.[0]?.delta?.content ?? '', - usage: parsedData.usage }; } catch (e) { console.error('Error extracting delta from SSE event:', e); @@ -98,10 +102,26 @@ async function* streamLargeDeltasAsRandomChunks( yield textStreamUpdate; return; } - if (textStreamUpdate.citations) { + + if (textStreamUpdate.error) { yield textStreamUpdate; continue; } + if (textStreamUpdate.sources) { + yield textStreamUpdate; + continue; + } + if (textStreamUpdate.selectedModelId) { + yield textStreamUpdate; + continue; + } + if (textStreamUpdate.usage) { + yield textStreamUpdate; + continue; + } + + + let content = textStreamUpdate.value; if (content.length < 5) { yield { done: false, value: content }; diff --git a/src/lib/apis/tools/index.ts b/src/lib/apis/tools/index.ts index 28e8dde86..d1dc11c16 100644 --- a/src/lib/apis/tools/index.ts +++ b/src/lib/apis/tools/index.ts @@ -62,6 +62,37 @@ export const getTools = async (token: string = '') => { return res; }; +export const getToolList = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/list`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const exportTools = async (token: string = '') => { let error = null; diff --git a/src/lib/apis/users/index.ts b/src/lib/apis/users/index.ts index 0b22b7171..b0efe39d2 100644 --- a/src/lib/apis/users/index.ts +++ b/src/lib/apis/users/index.ts @@ -1,10 +1,10 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; import { getUserPosition } from '$lib/utils'; -export const getUserPermissions = async (token: string) => { +export const getUserGroups = async (token: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/users/permissions/user`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/users/groups`, { method: 'GET', headers: { 'Content-Type': 'application/json', @@ -28,10 +28,37 @@ export const getUserPermissions = async (token: string) => { return res; }; -export const updateUserPermissions = async (token: string, permissions: object) => { +export const getUserDefaultPermissions = async (token: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/users/permissions/user`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/users/default/permissions`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateUserDefaultPermissions = async (token: string, permissions: object) => { + let error = null; + + const res = await 
fetch(`${WEBUI_API_BASE_URL}/users/default/permissions`, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/src/lib/components/ChangelogModal.svelte b/src/lib/components/ChangelogModal.svelte index 9002b9100..b395ddcbd 100644 --- a/src/lib/components/ChangelogModal.svelte +++ b/src/lib/components/ChangelogModal.svelte @@ -22,7 +22,7 @@ }); - +
@@ -59,7 +59,7 @@
-
+
{#if changelog} {#each Object.keys(changelog) as version} @@ -111,7 +111,7 @@ await updateUserSettings(localStorage.token, { ui: $settings }); show = false; }} - class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg" + class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full" > {$i18n.t("Okay, Let's Go!")} diff --git a/src/lib/components/OnBoarding.svelte b/src/lib/components/OnBoarding.svelte new file mode 100644 index 000000000..2add98a6c --- /dev/null +++ b/src/lib/components/OnBoarding.svelte @@ -0,0 +1,78 @@ + + +{#if show} +
+
+
+
+ logo +
+
+
+ + + +
+ +
+ +
+
+
+ + +
{$i18n.t(`wherever you are`)}
+ + +
+
+ +
{$i18n.t(`Get started`)}
+
+
+ + + + + +{/if} diff --git a/src/lib/components/admin/Evaluations.svelte b/src/lib/components/admin/Evaluations.svelte index 4632238f0..a5532ae2f 100644 --- a/src/lib/components/admin/Evaluations.svelte +++ b/src/lib/components/admin/Evaluations.svelte @@ -1,677 +1,100 @@ - {#if loaded} -
-
-
- {$i18n.t('Leaderboard')} -
- -
- - {rankedModels.length} +
+
- -
- -
-
- -
- { - loadEmbeddingModel(); - }} - /> -
-
-
-
- -
- {#if loadingLeaderboard} -
-
- -
-
- {/if} - {#if (rankedModels ?? []).length === 0} -
- {$i18n.t('No models found')} -
- {:else} - - - - - - - - - - - - {#each rankedModels as model, modelIdx (model.id)} - - - - - - - - - - {/each} - -
- {$i18n.t('RK')} - - {$i18n.t('Model')} - - {$i18n.t('Rating')} - - {$i18n.t('Won')} - - {$i18n.t('Lost')} -
-
- {model?.rating !== '-' ? modelIdx + 1 : '-'} -
-
-
-
- {model.name} -
- -
- {model.name} -
-
-
- {model.rating} - -
- {#if model.stats.won === '-'} - - - {:else} - - {model.stats.won} - {/if} -
-
-
- {#if model.stats.lost === '-'} - - - {:else} - - {model.stats.lost} - {/if} -
-
- {/if} -
- -
-
-
- ⓘ {$i18n.t( - 'The evaluation leaderboard is based on the Elo rating system and is updated in real-time.' - )} -
- {$i18n.t( - 'The leaderboard is currently in beta, and we may adjust the rating calculations as we refine the algorithm.' - )} -
-
- -
- -
-
- {$i18n.t('Feedback History')} - -
- - {feedbacks.length} -
- -
-
- - - -
-
-
+ + +
+
{$i18n.t('Leaderboard')}
+ -
- {#if (feedbacks ?? []).length === 0} -
- {$i18n.t('No feedbacks found')} -
- {:else} - { + selectedTab = 'feedbacks'; + }} > - - - - - - - - - - - - - - - {#each paginatedFeedbacks as feedback (feedback.id)} - - - - - - - - - - - {/each} - -
- {$i18n.t('User')} - - {$i18n.t('Models')} - - {$i18n.t('Result')} - - {$i18n.t('Updated At')} -
-
- -
- {feedback?.user?.name} -
-
-
-
-
-
- {#if feedback.data?.sibling_model_ids} -
- {feedback.data?.model_id} -
- - -
- {#if feedback.data.sibling_model_ids.length > 2} - - {feedback.data.sibling_model_ids.slice(0, 2).join(', ')}, {$i18n.t( - 'and {{COUNT}} more', - { COUNT: feedback.data.sibling_model_ids.length - 2 } - )} - {:else} - {feedback.data.sibling_model_ids.join(', ')} - {/if} -
-
- {:else} -
- {feedback.data?.model_id} -
- {/if} -
-
-
-
- {#if feedback.data.rating.toString() === '1'} - - {:else if feedback.data.rating.toString() === '0'} - - {:else if feedback.data.rating.toString() === '-1'} - - {/if} -
-
- {dayjs(feedback.updated_at * 1000).fromNow()} - - { - deleteFeedbackHandler(feedback.id); - }} - > - - -
- {/if} -
- - {#if feedbacks.length > 0} -
-
- {$i18n.t('Help us create the best community leaderboard by sharing your feedback history!')} -
- -
- - - -
+ + +
+
{$i18n.t('Feedbacks')}
+
- {/if} - {#if feedbacks.length > 10} - - {/if} - -
+
+ {#if selectedTab === 'leaderboard'} + + {:else if selectedTab === 'feedbacks'} + + {/if} +
+
{/if} diff --git a/src/lib/components/admin/Evaluations/Feedbacks.svelte b/src/lib/components/admin/Evaluations/Feedbacks.svelte new file mode 100644 index 000000000..e43081302 --- /dev/null +++ b/src/lib/components/admin/Evaluations/Feedbacks.svelte @@ -0,0 +1,283 @@ + + +
+
+ {$i18n.t('Feedback History')} + +
+ + {feedbacks.length} +
+ +
+
+ + + +
+
+
+ +
+ {#if (feedbacks ?? []).length === 0} +
+ {$i18n.t('No feedbacks found')} +
+ {:else} + + + + + + + + + + + + + + + + {#each paginatedFeedbacks as feedback (feedback.id)} + + + + + + + + + + + {/each} + +
+ {$i18n.t('User')} + + {$i18n.t('Models')} + + {$i18n.t('Result')} + + {$i18n.t('Updated At')} +
+
+ +
+ {feedback?.user?.name} +
+
+
+
+
+
+ {#if feedback.data?.sibling_model_ids} +
+ {feedback.data?.model_id} +
+ + +
+ {#if feedback.data.sibling_model_ids.length > 2} + + {feedback.data.sibling_model_ids.slice(0, 2).join(', ')}, {$i18n.t( + 'and {{COUNT}} more', + { COUNT: feedback.data.sibling_model_ids.length - 2 } + )} + {:else} + {feedback.data.sibling_model_ids.join(', ')} + {/if} +
+
+ {:else} +
+ {feedback.data?.model_id} +
+ {/if} +
+
+
+
+ {#if feedback.data.rating.toString() === '1'} + + {:else if feedback.data.rating.toString() === '0'} + + {:else if feedback.data.rating.toString() === '-1'} + + {/if} +
+
+ {dayjs(feedback.updated_at * 1000).fromNow()} + + { + deleteFeedbackHandler(feedback.id); + }} + > + + +
+ {/if} +
+ +{#if feedbacks.length > 0} +
+
+ {$i18n.t('Help us create the best community leaderboard by sharing your feedback history!')} +
+ +
+ + + +
+
+{/if} + +{#if feedbacks.length > 10} + +{/if} diff --git a/src/lib/components/admin/Evaluations/Leaderboard.svelte b/src/lib/components/admin/Evaluations/Leaderboard.svelte new file mode 100644 index 000000000..59f6df916 --- /dev/null +++ b/src/lib/components/admin/Evaluations/Leaderboard.svelte @@ -0,0 +1,410 @@ + + +
+
+
+ {$i18n.t('Leaderboard')} +
+ +
+ + {rankedModels.length} +
+ +
+ +
+
+ +
+ { + loadEmbeddingModel(); + }} + /> +
+
+
+
+ +
+ {#if loadingLeaderboard} +
+
+ +
+
+ {/if} + {#if (rankedModels ?? []).length === 0} +
+ {$i18n.t('No models found')} +
+ {:else} + + + + + + + + + + + + {#each rankedModels as model, modelIdx (model.id)} + + + + + + + + + + {/each} + +
+ {$i18n.t('RK')} + + {$i18n.t('Model')} + + {$i18n.t('Rating')} + + {$i18n.t('Won')} + + {$i18n.t('Lost')} +
+
+ {model?.rating !== '-' ? modelIdx + 1 : '-'} +
+
+
+
+ {model.name} +
+ +
+ {model.name} +
+
+
+ {model.rating} + +
+ {#if model.stats.won === '-'} + - + {:else} + + {model.stats.won} + {/if} +
+
+
+ {#if model.stats.lost === '-'} + - + {:else} + + {model.stats.lost} + {/if} +
+
+ {/if} +
+ +
+
+
+ ⓘ {$i18n.t( + 'The evaluation leaderboard is based on the Elo rating system and is updated in real-time.' + )} +
+ {$i18n.t( + 'The leaderboard is currently in beta, and we may adjust the rating calculations as we refine the algorithm.' + )} +
+
diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/admin/Functions.svelte similarity index 84% rename from src/lib/components/workspace/Functions.svelte rename to src/lib/components/admin/Functions.svelte index c488c57c2..03da04ea7 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/admin/Functions.svelte @@ -5,7 +5,6 @@ import { WEBUI_NAME, config, functions, models } from '$lib/stores'; import { onMount, getContext, tick } from 'svelte'; - import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; import { goto } from '$app/navigation'; import { @@ -25,11 +24,14 @@ import FunctionMenu from './Functions/FunctionMenu.svelte'; import EllipsisHorizontal from '../icons/EllipsisHorizontal.svelte'; import Switch from '../common/Switch.svelte'; - import ValvesModal from './common/ValvesModal.svelte'; - import ManifestModal from './common/ManifestModal.svelte'; + import ValvesModal from '../workspace/common/ValvesModal.svelte'; + import ManifestModal from '../workspace/common/ManifestModal.svelte'; import Heart from '../icons/Heart.svelte'; import DeleteConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; import GarbageBin from '../icons/GarbageBin.svelte'; + import Search from '../icons/Search.svelte'; + import Plus from '../icons/Plus.svelte'; + import ChevronRight from '../icons/ChevronRight.svelte'; const i18n = getContext('i18n'); @@ -48,12 +50,14 @@ let showDeleteConfirm = false; let filteredItems = []; - $: filteredItems = $functions.filter( - (f) => - query === '' || - f.name.toLowerCase().includes(query.toLowerCase()) || - f.id.toLowerCase().includes(query.toLowerCase()) - ); + $: filteredItems = $functions + .filter( + (f) => + query === '' || + f.name.toLowerCase().includes(query.toLowerCase()) || + f.id.toLowerCase().includes(query.toLowerCase()) + ) + .sort((a, b) => a.type.localeCompare(b.type) || a.name.localeCompare(b.name)); const shareHandler = async (func) => { const item = await getFunctionById(localStorage.token, func.id).catch((error) => { @@ -94,7 +98,7 @@ id: `${_function.id}_clone`, name: `${_function.name} (Clone)` }); - goto('/workspace/functions/create'); + goto('/admin/functions/create'); } }; @@ -182,68 +186,46 @@ -
-
-
- - - -
- -
- - -
- -
+
-
+
{$i18n.t('Functions')}
- {filteredItems.length}{filteredItems.length} +
+
+ +
+
+
+ +
+ +
+ +
-
- {#each filteredItems as func} +
+ {#each filteredItems as func (func.id)}
@@ -340,7 +322,7 @@ { - goto(`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`); + goto(`/admin/functions/edit?id=${encodeURIComponent(func.id)}`); }} shareHandler={() => { shareHandler(func); @@ -470,40 +452,27 @@ {#if $config?.features.enable_community_sharing}
- {/if} diff --git a/src/lib/components/workspace/Functions/FunctionEditor.svelte b/src/lib/components/admin/Functions/FunctionEditor.svelte similarity index 91% rename from src/lib/components/workspace/Functions/FunctionEditor.svelte rename to src/lib/components/admin/Functions/FunctionEditor.svelte index ec65bae5d..187110be0 100644 --- a/src/lib/components/workspace/Functions/FunctionEditor.svelte +++ b/src/lib/components/admin/Functions/FunctionEditor.svelte @@ -305,7 +305,7 @@ class Pipe:
- + + +
@@ -329,31 +331,37 @@ class Pipe:
-
+
{#if edit}
{id}
{:else} - + + + {/if} - + + +
diff --git a/src/lib/components/workspace/Functions/FunctionMenu.svelte b/src/lib/components/admin/Functions/FunctionMenu.svelte similarity index 100% rename from src/lib/components/workspace/Functions/FunctionMenu.svelte rename to src/lib/components/admin/Functions/FunctionMenu.svelte diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index 7e1e489b1..f0886ea5c 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -2,11 +2,11 @@ import { getContext, tick, onMount } from 'svelte'; import { toast } from 'svelte-sonner'; + import { config } from '$lib/stores'; + import { getBackendConfig } from '$lib/apis'; import Database from './Settings/Database.svelte'; import General from './Settings/General.svelte'; - import Users from './Settings/Users.svelte'; - import Pipelines from './Settings/Pipelines.svelte'; import Audio from './Settings/Audio.svelte'; import Images from './Settings/Images.svelte'; @@ -15,8 +15,7 @@ import Connections from './Settings/Connections.svelte'; import Documents from './Settings/Documents.svelte'; import WebSearch from './Settings/WebSearch.svelte'; - import { config } from '$lib/stores'; - import { getBackendConfig } from '$lib/apis'; + import ChartBar from '../icons/ChartBar.svelte'; import DocumentChartBar from '../icons/DocumentChartBar.svelte'; import Evaluations from './Settings/Evaluations.svelte'; @@ -39,16 +38,16 @@ }); -
+
- -
-
+
{#if selectedTab === 'general'} { @@ -361,12 +336,6 @@ await config.set(await getBackendConfig()); }} /> - {:else if selectedTab === 'users'} - { - toast.success($i18n.t('Settings saved successfully!')); - }} - /> {:else if selectedTab === 'connections'} { diff --git a/src/lib/components/admin/Settings/Audio.svelte b/src/lib/components/admin/Settings/Audio.svelte index ae827e6ab..a7a030027 100644 --- a/src/lib/components/admin/Settings/Audio.svelte +++ b/src/lib/components/admin/Settings/Audio.svelte @@ -181,7 +181,7 @@
+ @@ -333,7 +334,7 @@
+ {:else if TTS_ENGINE === 'transformers'} +
+
{$i18n.t('TTS Model')}
+
+
+ + + + +
+
+
+ {$i18n.t(`Open WebUI uses SpeechT5 and CMU Arctic speaker embeddings.`)} + + To learn more about SpeechT5, + + + {$i18n.t(`click here`, { + name: 'SpeechT5' + })}. + + To see the available CMU Arctic speaker embeddings, + + {$i18n.t(`click here`)}. + +
+
{:else if TTS_ENGINE === 'openai'}
diff --git a/src/lib/components/admin/Settings/Connections.svelte b/src/lib/components/admin/Settings/Connections.svelte index 292d22993..ddc19bb8f 100644 --- a/src/lib/components/admin/Settings/Connections.svelte +++ b/src/lib/components/admin/Settings/Connections.svelte @@ -1,31 +1,23 @@ + + + +
{ updateOpenAIHandler(); - updateOllamaUrlsHandler(); + updateOllamaHandler(); dispatch('save'); }} > -
+
{#if ENABLE_OPENAI_API !== null && ENABLE_OLLAMA_API !== null} -
+
{$i18n.t('OpenAI API')}
-
- { - updateOpenAIConfig(localStorage.token, ENABLE_OPENAI_API); - }} - /> +
+
+ { + updateOpenAIHandler(); + }} + /> +
{#if ENABLE_OPENAI_API} -
- {#each OPENAI_API_BASE_URLS as url, idx} -
-
- +
- {#if pipelineUrls[url]} -
- - - - - - - -
- {/if} -
+
+
+
{$i18n.t('Manage OpenAI API Connections')}
- + + +
+ +
+ {#each OPENAI_API_BASE_URLS as url, idx} + { + updateOpenAIHandler(); + }} + onDelete={() => { + OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.filter( + (url, urlIdx) => idx !== urlIdx + ); + OPENAI_API_KEYS = OPENAI_API_KEYS.filter((key, keyIdx) => idx !== keyIdx); + }} /> -
- {#if idx === 0} - - {:else} - - {/if} -
- -
- - - -
-
-
- {$i18n.t('WebUI will make requests to')} - '{url}/models' -
- {/each} + {/each} +
{/if}
-
+
-
-
+
+
{$i18n.t('Ollama API')}
{ - updateOllamaConfig(localStorage.token, ENABLE_OLLAMA_API); - - if (OLLAMA_BASE_URLS.length === 0) { - OLLAMA_BASE_URLS = ['']; - } + updateOllamaHandler(); }} />
+ {#if ENABLE_OLLAMA_API} -
-
- {#each OLLAMA_BASE_URLS as url, idx} -
- +
-
- {#if idx === 0} - - {:else} - - {/if} -
+
+
+
{$i18n.t('Manage Ollama API Connections')}
-
- - - -
-
- {/each} + + +
-
-
- {$i18n.t('Trouble accessing Ollama?')} - - {$i18n.t('Click here for help.')} - +
+
+ {#each OLLAMA_BASE_URLS as url, idx} + { + updateOllamaHandler(); + }} + onDelete={() => { + OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url, urlIdx) => idx !== urlIdx); + }} + /> + {/each} +
+
+ +
+ {$i18n.t('Trouble accessing Ollama?')} + + {$i18n.t('Click here for help.')} + +
{/if}
diff --git a/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte b/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte new file mode 100644 index 000000000..3f24dc6d7 --- /dev/null +++ b/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte @@ -0,0 +1,365 @@ + + + +
+
+
+ {#if edit} + {$i18n.t('Edit Connection')} + {:else} + {$i18n.t('Add Connection')} + {/if} +
+ +
+ +
+
+ { + e.preventDefault(); + submitHandler(); + }} + > +
+
+
+
{$i18n.t('URL')}
+ +
+ +
+
+ + + + + +
+ + + +
+
+ +
+
+
{$i18n.t('Key')}
+ +
+ +
+
+ +
+
{$i18n.t('Prefix ID')}
+ +
+ + + +
+
+
+ +
+ +
+
+
{$i18n.t('Model IDs')}
+
+ + {#if modelIds.length > 0} +
+ {#each modelIds as modelId, modelIdx} +
+
+ {modelId} +
+
+ +
+
+ {/each} +
+ {:else} +
+ {#if ollama} + {$i18n.t('Leave empty to include all models from "{{URL}}/api/tags" endpoint', { + URL: url + })} + {:else} + {$i18n.t('Leave empty to include all models from "{{URL}}/models" endpoint', { + URL: url + })} + {/if} +
+ {/if} +
+ +
+ +
+ + +
+ +
+
+
+ +
+ {#if edit} + + {/if} + + +
+ +
+
+
+
diff --git a/src/lib/components/admin/Settings/Connections/ManageOllamaModal.svelte b/src/lib/components/admin/Settings/Connections/ManageOllamaModal.svelte new file mode 100644 index 000000000..220214ed1 --- /dev/null +++ b/src/lib/components/admin/Settings/Connections/ManageOllamaModal.svelte @@ -0,0 +1,1054 @@ + + + { + deleteModelHandler(); + }} +/> + + +
+
+
+
+ {$i18n.t('Manage Ollama')} +
+ +
+ + + +
+
+ +
+ +
+ {#if !loading} +
+
+
+ {#if updateModelId} +
+ Updating "{updateModelId}" {updateProgress ? `(${updateProgress}%)` : ''} +
+ {/if} + +
+
+ {$i18n.t('Pull a model from Ollama.com')} +
+
+
+ +
+ +
+ +
+ {$i18n.t('To access the available model names for downloading,')} + {$i18n.t('click here.')} +
+ + {#if Object.keys($MODEL_DOWNLOAD_POOL).length > 0} + {#each Object.keys($MODEL_DOWNLOAD_POOL) as model} + {#if 'pullProgress' in $MODEL_DOWNLOAD_POOL[model]} +
+
{model}
+
+
+
+
+ {$MODEL_DOWNLOAD_POOL[model].pullProgress ?? 0}% +
+
+ + + + +
+ {#if 'digest' in $MODEL_DOWNLOAD_POOL[model]} +
+ {$MODEL_DOWNLOAD_POOL[model].digest} +
+ {/if} +
+
+ {/if} + {/each} + {/if} +
+ +
+
{$i18n.t('Delete a model')}
+
+
+ +
+ +
+
+ +
+
{$i18n.t('Create a model')}
+
+
+ + +