fix merge conflicts

This commit is contained in:
Taylor Wilsdon 2024-12-16 15:29:49 -05:00
commit 7c9f04bd59
383 changed files with 45034 additions and 23866 deletions

View File

@@ -16,4 +16,5 @@ _old
uploads
.ipynb_checkpoints
**/*.db
_test
_test
backend/data/*

View File

@@ -28,6 +28,8 @@ jobs:
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          lfs: true
      - name: Remove git history
        run: rm -rf .git
@@ -52,7 +54,9 @@ jobs:
      - name: Set up Git and push to Space
        run: |
          git init --initial-branch=main
          git lfs install
          git lfs track "*.ttf"
          git lfs track "*.jpg"
          rm demo.gif
          git add .
          git commit -m "GitHub deploy: ${{ github.sha }}"
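
For rehearsing the LFS steps above locally, here is a minimal Python sketch using `subprocess`; it assumes `git` and `git-lfs` are installed and should be run in a scratch directory, not a real checkout. The commit message is a placeholder.

```python
# Minimal local rehearsal of the workflow's LFS steps (assumptions: git and
# git-lfs are installed; run inside a throwaway directory).
import subprocess

def run(cmd: list[str]) -> None:
    # Echo each command and fail fast, mirroring the CI step's behavior.
    print("+", " ".join(cmd))
    subprocess.run(cmd, check=True)

run(["git", "init", "--initial-branch=main"])
run(["git", "lfs", "install"])
run(["git", "lfs", "track", "*.ttf"])  # fonts go through LFS
run(["git", "lfs", "track", "*.jpg"])  # images go through LFS
run(["git", "add", "."])
run(["git", "commit", "-m", "GitHub deploy: local test"])
```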

View File

@@ -5,10 +5,177 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
### Added
- **🌐 Enhanced Translations**: Added Slovak language, improved Czech language.
## [0.4.8] - 2024-12-07
### Added
- **🔓 Bypass Model Access Control**: Introduced the 'BYPASS_MODEL_ACCESS_CONTROL' environment variable. Easily bypass model access controls for user roles when access control isn't required, simplifying workflows for trusted environments (see the sketch after this list).
- **📝 Markdown in Banners**: Now supports markdown for banners, enabling richer, more visually engaging announcements.
- **🌐 Internationalization Updates**: Enhanced translations across multiple languages, further improving accessibility and global user experience.
- **🎨 Styling Enhancements**: General UI style refinements for a cleaner and more polished interface.
- **📋 Rich Text Reliability**: Improved the reliability and stability of rich text input across chats for smoother interactions.
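
A minimal sketch of how an environment flag like `BYPASS_MODEL_ACCESS_CONTROL` might be consumed; the helper function and its arguments below are hypothetical and not Open WebUI's actual access-control code.

```python
# Hypothetical sketch of consuming a bypass flag; the helper below is
# illustrative only, not Open WebUI internals.
import os

BYPASS_MODEL_ACCESS_CONTROL = (
    os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "false").lower() == "true"
)

def user_can_access_model(user_role: str, model_id: str, allowed_ids: set[str]) -> bool:
    if BYPASS_MODEL_ACCESS_CONTROL:
        return True  # trusted environment: skip per-model checks entirely
    if user_role == "admin":
        return True
    return model_id in allowed_ids

print(user_can_access_model("user", "llama3.1", {"llama3.1"}))  # True
```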
### Fixed
- **💡 Tailwind Build Issue**: Resolved a breaking bug caused by Tailwind, ensuring smoother builds and overall system reliability.
- **📚 Knowledge Collection Query Fix**: Addressed API endpoint issues with querying knowledge collections, ensuring accurate and reliable information retrieval.
## [0.4.7] - 2024-12-01
### Added
- **✨ Prompt Input Auto-Completion**: Type a prompt and let AI intelligently suggest and complete your inputs. Simply press 'Tab' or swipe right on mobile to confirm. Available only with Rich Text Input (default setting). Disable via Admin Settings for full control.
- **🌍 Improved Translations**: Enhanced localization for multiple languages, ensuring a more polished and accessible experience for international users.
### Fixed
- **🛠️ Tools Export Issue**: Resolved a critical issue where exporting tools wasn't functioning, restoring seamless export capabilities.
- **🔗 Model ID Registration**: Fixed an issue where model IDs weren't registering correctly in the model editor, ensuring reliable model setup and tracking.
- **🖋️ Textarea Auto-Expansion**: Corrected a bug where textareas didn't expand automatically on certain browsers, improving usability for multi-line inputs.
- **🔧 Ollama Embed Endpoint**: Addressed the /ollama/embed endpoint malfunction, ensuring consistent performance and functionality.
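
A hedged client sketch for exercising the repaired embedding route; the base URL, route prefix, token, and payload keys are assumptions modeled on Ollama's embed API rather than anything confirmed by this changelog.

```python
# Hypothetical call against the fixed embedding route; URL, prefix, token,
# and payload keys are assumptions for illustration only.
import requests

resp = requests.post(
    "http://localhost:3000/ollama/api/embed",          # assumed proxy path
    headers={"Authorization": "Bearer YOUR_API_KEY"},   # placeholder token
    json={"model": "nomic-embed-text", "input": "hello world"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```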
### Changed
- **🎨 Knowledge Base Styling**: Refined knowledge base visuals for a cleaner, more modern look, laying the groundwork for further enhancements in upcoming releases.
## [0.4.6] - 2024-11-26
### Added
- **🌍 Enhanced Translations**: Various language translations improved to make the WebUI more accessible and user-friendly worldwide.
### Fixed
- **✏️ Textarea Shifting Bug**: Resolved the issue where the textarea shifted unexpectedly, ensuring a smoother typing experience.
- **⚙️ Model Configuration Modal**: Fixed the issue where the model configuration modal introduced in 0.4.5 wasn't working for some users.
- **🔍 Legacy Query Support**: Restored functionality for custom query generation in RAG when using legacy prompts, ensuring both default and custom templates now work seamlessly.
- **⚡ Improved General Reliability**: Various minor fixes improve platform stability and ensure a smoother overall experience across workflows.
## [0.4.5] - 2024-11-26
### Added
- **🎨 Model Order/Defaults Reintroduced**: Brought back the ability to set model order and default models, now configurable via Admin Settings > Models > Configure (Gear Icon).
### Fixed
- **🔍 Query Generation Issue**: Resolved an error in web search query generation, enhancing search accuracy and ensuring smoother search workflows.
- **📏 Textarea Auto Height Bug**: Fixed a layout issue where textarea input height was shifting unpredictably, particularly when editing system prompts.
- **🔑 Ollama Authentication**: Corrected an issue with Ollama's authorization headers, guaranteeing reliable authentication across all endpoints.
- **⚙️ Missing Min_P Save**: Resolved an issue where the 'min_p' parameter was not being saved in configurations (a request sketch follows this list).
- **🛠️ Tools Description**: Fixed a key issue that omitted tool descriptions in tools payload.
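
A hedged sketch of a chat completion request carrying `min_p`; the endpoint path, token, and model name are placeholders, and forwarding `min_p` at the top level of the payload is an assumption for illustration only.

```python
# Hedged request sketch showing where a `min_p` sampling parameter might
# travel; endpoint, token, and model are placeholders.
import requests

payload = {
    "model": "llama3.1",
    "messages": [{"role": "user", "content": "Give me one fun fact."}],
    "min_p": 0.05,  # the sampling parameter the 0.4.5 fix now persists
}
r = requests.post(
    "http://localhost:3000/api/chat/completions",      # assumed endpoint
    headers={"Authorization": "Bearer YOUR_API_KEY"},
    json=payload,
    timeout=60,
)
print(r.json())
```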
## [0.4.4] - 2024-11-22
### Added
- **🌐 Translation Updates**: Refreshed Catalan, Brazilian Portuguese, German, and Ukrainian translations, further enhancing the platform's accessibility and improving the experience for international users.
### Fixed
- **📱 Mobile Controls Visibility**: Resolved an issue where the controls button was not displaying on the new chats page for mobile users, ensuring smoother navigation and functionality on smaller screens.
- **📷 LDAP Profile Image Issue**: Fixed an LDAP integration bug related to profile images, ensuring seamless authentication and a reliable login experience for users.
- **⏳ RAG Query Generation Issue**: Addressed a significant problem where RAG query generation occurred unnecessarily without attached files, drastically improving speed and reducing delays during chat completions.
### Changed
- **⚙️ Legacy Event Emitter Support**: Reintroduced compatibility with legacy "citation" types for event emitters in tools and functions, providing smoother workflows and broader tool support for users.
## [0.4.3] - 2024-11-21
### Added
- **📚 Inline Citations for RAG Results**: Get seamless inline citations for Retrieval-Augmented Generation (RAG) responses using the default RAG prompt. Note: This feature only supports newly uploaded files, improving traceability and providing source clarity.
- **🎨 Better Rich Text Input Support**: Enjoy smoother and more reliable rich text formatting for chats, enhancing communication quality.
- **⚡ Faster Model Retrieval**: Implemented caching optimizations for faster model loading, providing a noticeable speed boost across workflows. Further improvements are on the way!
### Fixed
- **🔗 Pipelines Feature Restored**: Resolved a critical issue that previously prevented Pipelines from functioning, ensuring seamless workflows.
- **✏️ Missing Suffix Field in Ollama Form**: Added the missing "suffix" field to the Ollama generate form, enhancing customization options.
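
A hedged example of the `suffix` field on Ollama's generate API, used for fill-in-the-middle completions; the host and model name are placeholders, and suffix support depends on the model's prompt template.

```python
# Hedged example of Ollama's generate API with a `suffix` field; host and
# model name are placeholders.
import requests

r = requests.post(
    "http://localhost:11434/api/generate",
    json={
        "model": "codellama:7b-code",       # assumed FIM-capable model
        "prompt": "def add(a, b):\n    ",
        "suffix": "\n    return result\n",  # text expected after the completion
        "stream": False,
    },
    timeout=120,
)
print(r.json().get("response"))
```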
### Changed
- **🗂️ Renamed "Citations" to "Sources"**: Improved clarity and consistency by renaming the "citations" field to "sources" in messages.
## [0.4.2] - 2024-11-20
### Fixed
- **📁 Knowledge Files Visibility Issue**: Resolved the bug preventing individual files in knowledge collections from displaying when referenced with '#'.
- **🔗 OpenAI Endpoint Prefix**: Fixed the issue where certain OpenAI connections that deviate from the official API spec weren't working correctly with prefixes.
- **⚔️ Arena Model Access Control**: Corrected an issue where arena model access control settings were not being saved.
- **🔧 Usage Capability Selector**: Fixed the broken usage capabilities selector in the model editor.
## [0.4.1] - 2024-11-19
### Added
- **📊 Enhanced Feedback System**: Introduced a detailed 1-10 rating scale for feedback alongside thumbs up/down, preparing for more precise model fine-tuning and improving feedback quality.
- **Tool Descriptions on Hover**: Easily access tool descriptions by hovering over the message input, providing a smoother workflow with more context when utilizing tools.
### Fixed
- **🗑️ Graceful Handling of Deleted Users**: Resolved an issue where deleted users caused workspace items (models, knowledge, prompts, tools) to fail, ensuring reliable workspace loading.
- **🔑 API Key Creation**: Fixed an issue preventing users from creating new API keys, restoring secure and seamless API management.
- **🔗 HTTPS Proxy Fix**: Corrected HTTPS proxy issues affecting the '/api/v1/models/' endpoint, ensuring smoother, uninterrupted model management.
## [0.4.0] - 2024-11-19
### Added
- **👥 User Groups**: You can now create and manage user groups, making user organization seamless.
- **🔐 Group-Based Access Control**: Set granular access to models, knowledge, prompts, and tools based on user groups, allowing for more controlled and secure environments.
- **🛠️ Group-Based User Permissions**: Easily manage workspace permissions. Grant users the ability to upload files, delete, edit, or create temporary chats, as well as define their ability to create models, knowledge, prompts, and tools.
- **🔑 LDAP Support**: Newly introduced LDAP authentication adds robust security and scalability to user management.
- **🌐 Enhanced OpenAI-Compatible Connections**: Added prefix ID support to avoid model ID clashes, with explicit model ID support for APIs lacking '/models' endpoint support, ensuring smooth operation with custom setups (see the sketch after this list).
- **🔐 Ollama API Key Support**: Now manage credentials for Ollama when set behind proxies, including the option to utilize prefix ID for proper distinction across multiple Ollama instances.
- **🔄 Connection Enable/Disable Toggle**: Easily enable or disable individual OpenAI and Ollama connections as needed.
- **🎨 Redesigned Model Workspace**: Freshly redesigned to improve usability for managing models across users and groups.
- **🎨 Redesigned Prompt Workspace**: A fresh UI to conveniently organize and manage prompts.
- **🧩 Sorted Functions Workspace**: Functions are now automatically categorized by type (Action, Filter, Pipe), streamlining management.
- **💻 Redesigned Collaborative Workspace**: Enhanced support for multiple users contributing to models, knowledge, prompts, or tools, improving collaboration.
- **🔧 Auto-Selected Tools in Model Editor**: Tools enabled through the model editor are now automatically selected, whereas previously they were only presented as an option to enable, reducing manual steps and enhancing efficiency.
- **🔔 Web Search & Tools Indicator**: A clear indication now shows when web search or tools are active, reducing confusion.
- **🔑 Toggle API Key Auth**: Tighten security by easily enabling or disabling API key authentication option for Open WebUI.
- **🗂️ Agentic Retrieval**: Improve RAG accuracy via smart pre-processing of chat history to determine the best queries before retrieval.
- **📁 Large Text as File Option**: Optionally convert large pasted text into a file upload, keeping the chat interface cleaner.
- **🗂️ Toggle Citations for Models**: Ability to disable citations has been introduced in the model editor.
- **🔍 User Settings Search**: Quickly search for settings fields, improving ease of use and navigation.
- **🗣️ Experimental SpeechT5 TTS**: Local SpeechT5 support added for improved text-to-speech capabilities.
- **🔄 Unified Reset for Models**: A one-click option has been introduced to reset and remove all models from the Admin Settings.
- **🛠️ Initial Setup Wizard**: The setup process now explicitly informs users that they are creating an admin account during the first-time setup, ensuring clarity. Previously, users encountered the login page right away without this distinction.
- **🌐 Enhanced Translations**: Several language translations, including Ukrainian, Norwegian, and Brazilian Portuguese, were refined for better localization.
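
An illustrative sketch (not Open WebUI internals) of how a per-connection prefix ID can keep identical model IDs from different OpenAI-compatible endpoints distinct; the "." separator and data shapes are assumptions.

```python
# Illustrative sketch of prefix-ID disambiguation across connections; the
# separator and structures below are assumptions, not project code.
connections = [
    {"prefix_id": "groq", "models": ["llama-3.1-70b", "mixtral-8x7b"]},
    {"prefix_id": "local", "models": ["llama-3.1-70b"]},  # same ID, different backend
]

merged = {}
for conn in connections:
    for model_id in conn["models"]:
        qualified = f"{conn['prefix_id']}.{model_id}" if conn["prefix_id"] else model_id
        merged[qualified] = {"id": model_id, "connection": conn["prefix_id"]}

print(sorted(merged))  # ['groq.llama-3.1-70b', 'groq.mixtral-8x7b', 'local.llama-3.1-70b']
```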
### Fixed
- **🎥 YouTube Video Attachments**: Fixed issues preventing proper loading and attachment of YouTube videos as files.
- **🔄 Shared Chat Update**: Corrected issues where shared chats were not updating, improving collaboration consistency.
- **🔍 DuckDuckGo Rate Limit Fix**: Addressed issues with DuckDuckGo search integration, enhancing search stability and performance when operating within rate limits.
- **🧾 Citations Relevance Fix**: Adjusted the relevance percentage calculation for citations so that Open WebUI properly reflects the accuracy of a retrieved document in RAG, ensuring users get clearer insights into sources (a brief sketch of such a calculation follows this list).
- **🔑 Jina Search API Key Requirement**: Added the option to input an API key for Jina Search, ensuring smooth functionality as keys are now mandatory.
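
A hedged illustration of mapping a retrieval similarity score to the percentage shown next to a source; the clamping and rounding choices are assumptions, not the project's exact formula.

```python
# Hedged illustration of a similarity-to-percentage mapping; the exact
# formula used by the project may differ.
def relevance_percentage(similarity: float) -> int:
    """Map a cosine-similarity-style score in [0, 1] to a 0-100% figure."""
    clamped = max(0.0, min(1.0, similarity))
    return round(clamped * 100)

print(relevance_percentage(0.873))  # 87
print(relevance_percentage(1.2))    # 100 (out-of-range scores are clamped)
```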
### Changed
- **🛠️ Functions Moved to Admin Panel**: As Functions operate as advanced plugins, they are now accessible from the Admin Panel instead of the workspace.
- **🛠️ Manage Ollama Connections**: The "Models" section in Admin Settings has been relocated to Admin Settings > "Connections" > Ollama Connections. You can now manage Ollama instances via a dedicated "Manage Ollama" modal from "Connections", streamlining the setup and configuration of Ollama models.
- **📊 Base Models in Admin Settings**: Admins can now find all base models, whether from connections or functions, in the "Models" admin settings. Global model accessibility can be enabled or disabled here. Models are private by default, requiring explicit permission assignment for user access.
- **📌 Sticky Model Selection for New Chats**: The model chosen from a previous chat now persists when creating a new chat. If you click "New Chat" again from the new chat page, it will revert to your default model.
- **🎨 Design Refactoring**: Overall design refinements across the platform have been made, providing a more cohesive and polished user experience.
### Removed
- **📂 Model List Reordering**: Temporarily removed and will be reintroduced in upcoming user group settings improvements.
- **⚙️ Default Model Setting**: Removed the ability to set a default model for users, will be reintroduced with user group settings in the future.
## [0.3.35] - 2024-10-26
### Added
- **🌐 Translation Update**: Added translation labels in the SearchInput and CreateCollection components and updated Brazilian Portuguese translation (pt-BR)
- **📁 Robust File Handling**: Enhanced file input handling for chat. If the content extraction fails or is empty, users will now receive a clear warning, preventing silent failures and ensuring you always know what's happening with your uploads.
- **🌍 New Language Support**: Introduced Hungarian translations and updated French translations, expanding the platform's language accessibility for a more global user base.

View File

@@ -2,76 +2,98 @@
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
As members, contributors, and leaders of this community, we pledge to make participation in our open-source project a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
We are committed to creating and maintaining an open, respectful, and professional environment where positive contributions and meaningful discussions can flourish. By participating in this project, you agree to uphold these values and align your behavior to the standards outlined in this Code of Conduct.
## Why These Standards Are Important
Open-source projects rely on a community of volunteers dedicating their time, expertise, and effort toward a shared goal. These projects are inherently collaborative but also fragile, as the success of the project depends on the goodwill, energy, and productivity of those involved.
Maintaining a positive and respectful environment is essential to safeguarding the integrity of this project and protecting contributors' efforts. Behavior that disrupts this atmosphere—whether through hostility, entitlement, or unprofessional conduct—can severely harm the morale and productivity of the community. **Strict enforcement of these standards ensures a safe and supportive space for meaningful collaboration.**
This is a community where **respect and professionalism are mandatory.** Violations of these standards will result in **zero tolerance** and immediate enforcement to prevent disruption and ensure the well-being of all participants.
## Our Standards
Examples of behavior that contribute to a positive environment for our community include:
Examples of behavior that contribute to a positive and professional community include:
- Demonstrating empathy and kindness toward other people
- Being respectful of differing opinions, viewpoints, and experiences
- Giving and gracefully accepting constructive feedback
- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
- Focusing on what is best not just for us as individuals, but for the overall community
- **Respecting others.** Be considerate, listen actively, and engage with empathy toward others' viewpoints and experiences.
- **Constructive feedback.** Provide actionable, thoughtful, and respectful feedback that helps improve the project and encourages collaboration. Avoid unproductive negativity or hypercriticism.
- **Recognizing volunteer contributions.** Appreciate that contributors dedicate their free time and resources selflessly. Approach them with gratitude and patience.
- **Focusing on shared goals.** Collaborate in ways that prioritize the health, success, and sustainability of the community over individual agendas.
Examples of unacceptable behavior include:
- The use of sexualized language or imagery, and sexual attention or advances of any kind
- Trolling, insulting or derogatory comments, and personal or political attacks
- Public or private harassment
- Publishing others' private information, such as a physical or email address, without their explicit permission
- **Spamming of any kind**
- Aggressive sales tactics targeting our community members are strictly prohibited. You can mention your product if it's relevant to the discussion, but under no circumstances should you push it forcefully
- Other conduct which could reasonably be considered inappropriate in a professional setting
- The use of discriminatory, demeaning, or sexualized language or behavior.
- Personal attacks, derogatory comments, trolling, or inflammatory political or ideological arguments.
- Harassment, intimidation, or any behavior intended to create a hostile, uncomfortable, or unsafe environment.
- Publishing others' private information (e.g., physical or email addresses) without explicit permission.
- **Entitlement, demand, or aggression toward contributors.** Volunteers are under no obligation to provide immediate or personalized support. Rude or dismissive behavior will not be tolerated.
- **Unproductive or destructive behavior.** This includes venting frustration as hostility ("tantrums"), hypercriticism, attention-seeking negativity, or anything that distracts from the project's goals.
- **Spamming and promotional exploitation.** Sharing irrelevant product promotions or self-promotion in the community is not allowed unless it directly contributes value to the discussion.
### Feedback and Community Engagement
- **Constructive feedback is encouraged, but hostile or entitled behavior will result in immediate action.** If you disagree with elements of the project, we encourage you to offer meaningful improvements or fork the project if necessary. Healthy discussions and technical disagreements are welcome only when handled with professionalism.
- **Respect contributors' time and efforts.** No one is entitled to personalized or on-demand assistance. This is a community built on collaboration and shared effort; demanding or demeaning behavior undermines that trust and will not be allowed.
### Zero Tolerance: No Warnings, Immediate Action
This community operates under a **zero-tolerance policy.** Any behavior deemed unacceptable under this Code of Conduct will result in **immediate enforcement, without prior warning.**
We employ this approach to ensure that unproductive or disruptive behavior does not escalate further or cause unnecessary harm to other contributors. The standards are clear, and violations of any kind—whether mild or severe—will be addressed decisively to protect the community.
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.
Community leaders are responsible for upholding and enforcing these standards. They are empowered to take **immediate and appropriate action** to address any behaviors they deem unacceptable under this Code of Conduct. These actions are taken with the goal of protecting the community and preserving its safe, positive, and productive environment.
## Scope
This Code of Conduct applies within all community spaces and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
This Code of Conduct applies to all community spaces, including forums, repositories, social media accounts, and in-person events. It also applies when an individual represents the community in public settings, such as conferences or official communications.
## Enforcement
Additionally, any behavior outside of these defined spaces that negatively impacts the community or its members may fall within the scope of this Code of Conduct.
Instances of abusive, harassing, spamming, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at hello@openwebui.com. All complaints will be reviewed and investigated promptly and fairly.
## Reporting Violations
All community leaders are obligated to respect the privacy and security of the reporter of any incident.
Instances of unacceptable behavior can be reported to the leadership team at **hello@openwebui.com**. Reports will be handled promptly, confidentially, and with consideration for the safety and well-being of the reporter.
All community leaders are required to uphold confidentiality and impartiality when addressing reports of violations.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:
### Ban
### 1. Temporary Ban
**Community Impact**: Community leaders will issue a ban to any participant whose behavior is deemed unacceptable according to this Code of Conduct. Bans are enforced immediately and without prior notice.
**Community Impact**: Any violation of community standards, including but not limited to inappropriate language, unprofessional behavior, harassment, or spamming.
A ban may be temporary or permanent, depending on the severity of the violation. This includes—but is not limited to—behavior such as:
**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
- Harassment or abusive behavior toward contributors.
- Persistent negativity or hostility that disrupts the collaborative environment.
- Disrespectful, demanding, or aggressive interactions with others.
- Attempts to cause harm or sabotage the community.
### 2. Permanent Ban
**Consequence**: A banned individual is immediately removed from access to all community spaces, communication channels, and events. Community leaders reserve the right to enforce either a time-limited suspension or a permanent ban based on the specific circumstances of the violation.
**Community Impact**: Repeated or severe violations of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
This approach ensures that disruptive behaviors are addressed swiftly and decisively in order to maintain the integrity and productivity of the community.
**Consequence**: A permanent ban from any sort of public interaction within the community.
## Why Zero Tolerance Is Necessary
Open-source projects thrive on collaboration, goodwill, and mutual respect. Toxic behaviors—such as entitlement, hostility, or persistent negativity—threaten not just individual contributors but the health of the project as a whole. Allowing such behaviors to persist robs contributors of their time, energy, and enthusiasm for the work they do.
By enforcing a zero-tolerance policy, we ensure that the community remains a safe, welcoming space for all participants. These measures are not about harshness—they are about protecting contributors and fostering a productive environment where innovation can thrive.
Our expectations are clear, and our enforcement reflects our commitment to this project's long-term success.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

View File

@@ -21,7 +21,7 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-
- 🤝 **Ollama/OpenAI API Integration**: Effortlessly integrate OpenAI-compatible APIs for versatile conversations alongside Ollama models. Customize the OpenAI API URL to link with **LMStudio, GroqCloud, Mistral, OpenRouter, and more**.
- 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
- 🛡️ **Granular Permissions and User Groups**: By allowing administrators to create detailed user roles and permissions, we ensure a secure user environment. This granularity not only enhances security but also allows for customized user experiences, fostering a sense of ownership and responsibility amongst users.
- 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
@@ -37,7 +37,7 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch` and `SearchApi` and inject the results directly into your chat experience.
- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch`, `SearchApi` and `Bing` and inject the results directly into your chat experience.
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
@@ -49,6 +49,8 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-
- 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
- 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
- 🌟 **Continuous Updates**: We are committed to improving Open WebUI with regular updates, fixes, and new features.
Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview!
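
The Pipelines bullet above describes pointing an OpenAI-style client at a Pipelines URL; below is a hedged sketch using the `openai` Python package, where the URL, port, API key, and pipeline id are placeholders.

```python
# Hedged sketch of an OpenAI-compatible client aimed at a Pipelines instance;
# URL, port, key, and model id are placeholders.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:9099/v1",  # assumed Pipelines URL
    api_key="YOUR_PIPELINES_KEY",         # placeholder key
)
completion = client.chat.completions.create(
    model="my-pipeline",                  # hypothetical pipeline id
    messages=[{"role": "user", "content": "Hello from a pipeline!"}],
)
print(completion.choices[0].message.content)
```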

View File

@@ -1,641 +0,0 @@
import hashlib
import json
import logging
import os
import uuid
from functools import lru_cache
from pathlib import Path
from pydub import AudioSegment
from pydub.silence import split_on_silence
import requests
from open_webui.config import (
AUDIO_STT_ENGINE,
AUDIO_STT_MODEL,
AUDIO_STT_OPENAI_API_BASE_URL,
AUDIO_STT_OPENAI_API_KEY,
AUDIO_TTS_API_KEY,
AUDIO_TTS_ENGINE,
AUDIO_TTS_MODEL,
AUDIO_TTS_OPENAI_API_BASE_URL,
AUDIO_TTS_OPENAI_API_KEY,
AUDIO_TTS_SPLIT_ON,
AUDIO_TTS_VOICE,
AUDIO_TTS_AZURE_SPEECH_REGION,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
CACHE_DIR,
CORS_ALLOW_ORIGIN,
WHISPER_MODEL,
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
AppConfig,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, DEVICE_TYPE
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from open_webui.utils.utils import get_admin_user, get_verified_user
# Constants
MAX_FILE_SIZE_MB = 25
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL
app.state.faster_whisper_model = None
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
log.info(f"whisper_device_type: {whisper_device_type}")
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
def set_faster_whisper_model(model: str, auto_update: bool = False):
    if model and app.state.config.STT_ENGINE == "":
        from faster_whisper import WhisperModel

        faster_whisper_kwargs = {
            "model_size_or_path": model,
            "device": whisper_device_type,
            "compute_type": "int8",
            "download_root": WHISPER_MODEL_DIR,
            "local_files_only": not auto_update,
        }
        try:
            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
        except Exception:
            log.warning(
                "WhisperModel initialization failed, attempting download with local_files_only=False"
            )
            faster_whisper_kwargs["local_files_only"] = False
            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
    else:
        app.state.faster_whisper_model = None
class TTSConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
API_KEY: str
ENGINE: str
MODEL: str
VOICE: str
SPLIT_ON: str
AZURE_SPEECH_REGION: str
AZURE_SPEECH_OUTPUT_FORMAT: str
class STTConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
ENGINE: str
MODEL: str
WHISPER_MODEL: str
class AudioConfigUpdateForm(BaseModel):
tts: TTSConfigForm
stt: STTConfigForm
from pydub import AudioSegment
from pydub.utils import mediainfo
def is_mp4_audio(file_path):
    """Check if the given file is an MP4 audio file."""
    if not os.path.isfile(file_path):
        print(f"File not found: {file_path}")
        return False

    info = mediainfo(file_path)
    if (
        info.get("codec_name") == "aac"
        and info.get("codec_type") == "audio"
        and info.get("codec_tag_string") == "mp4a"
    ):
        return True
    return False


def convert_mp4_to_wav(file_path, output_path):
    """Convert MP4 audio file to WAV format."""
    audio = AudioSegment.from_file(file_path, format="mp4")
    audio.export(output_path, format="wav")
    print(f"Converted {file_path} to {output_path}")
@app.get("/config")
async def get_audio_config(user=Depends(get_admin_user)):
return {
"tts": {
"OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
"API_KEY": app.state.config.TTS_API_KEY,
"ENGINE": app.state.config.TTS_ENGINE,
"MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
"SPLIT_ON": app.state.config.TTS_SPLIT_ON,
"AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
"AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
},
}
@app.post("/config/update")
async def update_audio_config(
form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
app.state.config.TTS_API_KEY = form_data.tts.API_KEY
app.state.config.TTS_ENGINE = form_data.tts.ENGINE
app.state.config.TTS_MODEL = form_data.tts.MODEL
app.state.config.TTS_VOICE = form_data.tts.VOICE
app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
)
app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
app.state.config.STT_ENGINE = form_data.stt.ENGINE
app.state.config.STT_MODEL = form_data.stt.MODEL
app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
return {
"tts": {
"OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
"API_KEY": app.state.config.TTS_API_KEY,
"ENGINE": app.state.config.TTS_ENGINE,
"MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
"SPLIT_ON": app.state.config.TTS_SPLIT_ON,
"AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
"AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
},
}
@app.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
body = await request.body()
name = hashlib.sha256(body).hexdigest()
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
# Check if the file already exists in the cache
if file_path.is_file():
return FileResponse(file_path)
if app.state.config.TTS_ENGINE == "openai":
headers = {}
headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
try:
body = body.decode("utf-8")
body = json.loads(body)
body["model"] = app.state.config.TTS_MODEL
body = json.dumps(body).encode("utf-8")
except Exception:
pass
r = None
try:
r = requests.post(
url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Save the streaming content to a file
with open(file_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
with open(file_body_path, "w") as f:
json.dump(json.loads(body.decode("utf-8")), f)
# Return the saved file
return FileResponse(file_path)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message']}"
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r != None else 500,
detail=error_detail,
)
elif app.state.config.TTS_ENGINE == "elevenlabs":
payload = None
try:
payload = json.loads(body.decode("utf-8"))
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
voice_id = payload.get("voice", "")
if voice_id not in get_available_voices():
raise HTTPException(
status_code=400,
detail="Invalid voice id",
)
url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
headers = {
"Accept": "audio/mpeg",
"Content-Type": "application/json",
"xi-api-key": app.state.config.TTS_API_KEY,
}
data = {
"text": payload["input"],
"model_id": app.state.config.TTS_MODEL,
"voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
}
try:
r = requests.post(url, json=data, headers=headers)
r.raise_for_status()
# Save the streaming content to a file
with open(file_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
with open(file_body_path, "w") as f:
json.dump(json.loads(body.decode("utf-8")), f)
# Return the saved file
return FileResponse(file_path)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message']}"
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r != None else 500,
detail=error_detail,
)
elif app.state.config.TTS_ENGINE == "azure":
payload = None
try:
payload = json.loads(body.decode("utf-8"))
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
region = app.state.config.TTS_AZURE_SPEECH_REGION
language = app.state.config.TTS_VOICE
locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1])
output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
headers = {
"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY,
"Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": output_format,
}
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
<voice name="{language}">{payload["input"]}</voice>
</speak>"""
response = requests.post(url, headers=headers, data=data)
if response.status_code == 200:
with open(file_path, "wb") as f:
f.write(response.content)
return FileResponse(file_path)
else:
log.error(f"Error synthesizing speech - {response.reason}")
raise HTTPException(
status_code=500, detail=f"Error synthesizing speech - {response.reason}"
)
def transcribe(file_path):
print("transcribe", file_path)
filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path)
id = filename.split(".")[0]
if app.state.config.STT_ENGINE == "":
if app.state.faster_whisper_model is None:
set_faster_whisper_model(app.state.config.WHISPER_MODEL)
model = app.state.faster_whisper_model
segments, info = model.transcribe(file_path, beam_size=5)
log.info(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
transcript = "".join([segment.text for segment in list(segments)])
data = {"text": transcript.strip()}
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
log.debug(data)
return data
elif app.state.config.STT_ENGINE == "openai":
if is_mp4_audio(file_path):
print("is_mp4_audio")
os.rename(file_path, file_path.replace(".wav", ".mp4"))
# Convert MP4 audio file to WAV format
convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
files = {"file": (filename, open(file_path, "rb"))}
data = {"model": app.state.config.STT_MODEL}
log.debug(files, data)
r = None
try:
r = requests.post(
url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
headers=headers,
files=files,
data=data,
)
r.raise_for_status()
data = r.json()
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
print(data)
return data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message']}"
except Exception:
error_detail = f"External: {e}"
raise Exception(error_detail)
@app.post("/transcriptions")
def transcription(
file: UploadFile = File(...),
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}")
if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
)
try:
ext = file.filename.split(".")[-1]
id = uuid.uuid4()
filename = f"{id}.{ext}"
contents = file.file.read()
file_dir = f"{CACHE_DIR}/audio/transcriptions"
os.makedirs(file_dir, exist_ok=True)
file_path = f"{file_dir}/{filename}"
with open(file_path, "wb") as f:
f.write(contents)
try:
if os.path.getsize(file_path) > MAX_FILE_SIZE: # file is bigger than 25MB
log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
audio = AudioSegment.from_file(file_path)
audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
compressed_path = f"{file_dir}/{id}_compressed.opus"
audio.export(compressed_path, format="opus", bitrate="32k")
log.debug(f"Compressed audio to {compressed_path}")
file_path = compressed_path
if (
os.path.getsize(file_path) > MAX_FILE_SIZE
): # Still larger than 25MB after compression
log.debug(
f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}"
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_TOO_LARGE(
size=f"{MAX_FILE_SIZE_MB}MB"
),
)
data = transcribe(file_path)
else:
data = transcribe(file_path)
file_path = file_path.split("/")[-1]
return {**data, "filename": file_path}
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
def get_available_models() -> list[dict]:
if app.state.config.TTS_ENGINE == "openai":
return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
elif app.state.config.TTS_ENGINE == "elevenlabs":
headers = {
"xi-api-key": app.state.config.TTS_API_KEY,
"Content-Type": "application/json",
}
try:
response = requests.get(
"https://api.elevenlabs.io/v1/models", headers=headers, timeout=5
)
response.raise_for_status()
models = response.json()
return [
{"name": model["name"], "id": model["model_id"]} for model in models
]
except requests.RequestException as e:
log.error(f"Error fetching voices: {str(e)}")
return []
@app.get("/models")
async def get_models(user=Depends(get_verified_user)):
return {"models": get_available_models()}
def get_available_voices() -> dict:
"""Returns {voice_id: voice_name} dict"""
ret = {}
if app.state.config.TTS_ENGINE == "openai":
ret = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
elif app.state.config.TTS_ENGINE == "elevenlabs":
try:
ret = get_elevenlabs_voices()
except Exception:
# Avoided @lru_cache with exception
pass
elif app.state.config.TTS_ENGINE == "azure":
try:
region = app.state.config.TTS_AZURE_SPEECH_REGION
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY}
response = requests.get(url, headers=headers)
response.raise_for_status()
voices = response.json()
for voice in voices:
ret[voice["ShortName"]] = (
f"{voice['DisplayName']} ({voice['ShortName']})"
)
except requests.RequestException as e:
log.error(f"Error fetching voices: {str(e)}")
return ret
@lru_cache
def get_elevenlabs_voices() -> dict:
    """
    Note, set the following in your .env file to use Elevenlabs:
    AUDIO_TTS_ENGINE=elevenlabs
    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
    AUDIO_TTS_MODEL=eleven_multilingual_v2
    """
    headers = {
        "xi-api-key": app.state.config.TTS_API_KEY,
        "Content-Type": "application/json",
    }
    try:
        # TODO: Add retries
        response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
        response.raise_for_status()
        voices_data = response.json()

        voices = {}
        for voice in voices_data.get("voices", []):
            voices[voice["voice_id"]] = voice["name"]
    except requests.RequestException as e:
        # Avoid @lru_cache with exception
        log.error(f"Error fetching voices: {str(e)}")
        raise RuntimeError(f"Error fetching voices: {str(e)}")

    return voices


@app.get("/voices")
async def get_voices(user=Depends(get_verified_user)):
    return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}

File diff suppressed because it is too large

View File

@@ -1,557 +0,0 @@
import asyncio
import hashlib
import json
import logging
from pathlib import Path
from typing import Literal, Optional, overload
import aiohttp
import requests
from open_webui.apps.webui.models.models import Models
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
ENABLE_MODEL_FILTER,
ENABLE_OPENAI_API,
MODEL_FILTER_LIST,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
AppConfig,
)
from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from open_webui.utils.utils import get_admin_user, get_verified_user
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.MODELS = {}
@app.middleware("http")
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
await get_all_models()
response = await call_next(request)
return response
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
class OpenAIConfigForm(BaseModel):
enable_openai_api: Optional[bool] = None
@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
class UrlsUpdateForm(BaseModel):
urls: list[str]
class KeysUpdateForm(BaseModel):
keys: list[str]
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
await get_all_models()
app.state.config.OPENAI_API_BASE_URLS = form_data.urls
return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
app.state.config.OPENAI_API_KEYS = form_data.keys
return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
idx = None
try:
idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
body = await request.body()
name = hashlib.sha256(body).hexdigest()
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
# Check if the file already exists in the cache
if file_path.is_file():
return FileResponse(file_path)
headers = {}
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
headers["Content-Type"] = "application/json"
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
headers["HTTP-Referer"] = "https://openwebui.com/"
headers["X-Title"] = "Open WebUI"
r = None
try:
r = requests.post(
url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Save the streaming content to a file
with open(file_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
with open(file_body_path, "w") as f:
json.dump(json.loads(body.decode("utf-8")), f)
# Return the saved file
return FileResponse(file_path)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)
except ValueError:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
async def fetch_url(url, key):
    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
    try:
        headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(url, headers=headers) as response:
                return await response.json()
    except Exception as e:
        # Handle connection error here
        log.error(f"Connection error: {e}")
        return None


async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    if response:
        response.close()
    if session:
        await session.close()
def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:
merged_list.extend(
[
{
**model,
"name": model.get("name", model["id"]),
"owned_by": "openai",
"openai": model,
"urlIdx": idx,
}
for model in models
if "api.openai.com"
not in app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
]
)
return merged_list
def is_openai_api_disabled():
return not app.state.config.ENABLE_OPENAI_API
async def get_all_models_raw() -> list:
if is_openai_api_disabled():
return []
# Check if API KEYS length is same than API URLS length
num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
num_keys = len(app.state.config.OPENAI_API_KEYS)
if num_keys != num_urls:
# if there are more keys than urls, remove the extra keys
if num_keys > num_urls:
new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
app.state.config.OPENAI_API_KEYS = new_keys
# if there are more urls than keys, add empty keys
else:
app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
tasks = [
fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
]
responses = await asyncio.gather(*tasks)
log.debug(f"get_all_models:responses() {responses}")
return responses
@overload
async def get_all_models(raw: Literal[True]) -> list: ...
@overload
async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
async def get_all_models(raw=False) -> dict[str, list] | list:
log.info("get_all_models()")
if is_openai_api_disabled():
return [] if raw else {"data": []}
responses = await get_all_models_raw()
if raw:
return responses
def extract_data(response):
if response and "data" in response:
return response["data"]
if isinstance(response, list):
return response
return None
models = {"data": merge_models_lists(map(extract_data, responses))}
log.debug(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]}
return models
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
if url_idx is None:
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models["data"] = list(
filter(
lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
models["data"],
)
)
return models
return models
else:
url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = app.state.config.OPENAI_API_KEYS[url_idx]
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.request(method="GET", url=f"{url}/models", headers=headers)
r.raise_for_status()
response_data = r.json()
if "api.openai.com" in url:
# Filter the response data
response_data["data"] = [
model
for model in response_data["data"]
if not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
]
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
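# Forward an OpenAI-style chat completion to the connection serving the selected
# model, applying per-model params, o1 compatibility fixes, and SSE streaming.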
@app.post("/chat/completions")
@app.post("/chat/completions/{url_idx}")
async def generate_chat_completion(
form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
idx = 0
payload = {**form_data}
if "metadata" in payload:
del payload["metadata"]
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
params = model_info.params.model_dump()
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, user)
model = app.state.MODELS[payload.get("model")]
idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"):
payload["user"] = {
"name": user.name,
"id": user.id,
"email": user.email,
"role": user.role,
}
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
is_o1 = payload["model"].lower().startswith("o1-")
# Reconcile max_tokens vs. max_completion_tokens (backward compatible):
# most endpoints expect max_tokens, while o1 models expect max_completion_tokens
if "api.openai.com" not in url and not is_o1:
if "max_completion_tokens" in payload:
# Remove "max_completion_tokens" from the payload
payload["max_tokens"] = payload["max_completion_tokens"]
del payload["max_completion_tokens"]
else:
if is_o1 and "max_tokens" in payload:
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
if "max_tokens" in payload and "max_completion_tokens" in payload:
del payload["max_tokens"]
# Fix: o1 models do not support the "system" role; convert a leading system message to "user"
if is_o1 and payload["messages"][0]["role"] == "system":
payload["messages"][0]["role"] = "user"
# Convert the modified body back to JSON
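# (serialized manually because it is passed to aiohttp as a raw body via data=)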
payload = json.dumps(payload)
log.debug(payload)
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
headers["HTTP-Referer"] = "https://openwebui.com/"
headers["X-Title"] = "Open WebUI"
r = None
session = None
streaming = False
response = None
try:
session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
r = await session.request(
method="POST",
url=f"{url}/chat/completions",
data=payload,
headers=headers,
)
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(
cleanup_response, response=r, session=session
),
)
else:
try:
response = await r.json()
except Exception as e:
log.error(e)
response = await r.text()
r.raise_for_status()
return response
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if isinstance(response, dict):
if "error" in response:
error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
elif isinstance(response, str):
error_detail = response
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:
if not streaming and session:
if r:
r.close()
await session.close()
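# Catch-all proxy: forward any other request verbatim to the first configured
# OpenAI-compatible endpoint (idx 0).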
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
idx = 0
body = await request.body()
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}"
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
r = None
session = None
streaming = False
try:
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(
cleanup_response, response=r, session=session
),
)
else:
response_data = await r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = await r.json()
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except Exception:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:
if not streaming and session:
if r:
r.close()
await session.close()

File diff suppressed because it is too large

View File

@ -1,14 +0,0 @@
from open_webui.config import VECTOR_DB
if VECTOR_DB == "milvus":
from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
VECTOR_DB_CLIENT = QdrantClient()
else:
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
VECTOR_DB_CLIENT = ChromaClient()

View File

@ -1,458 +0,0 @@
import inspect
import json
import logging
import time
from typing import AsyncGenerator, Generator, Iterator
from open_webui.apps.socket.main import get_event_call, get_event_emitter
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.routers import (
auths,
chats,
folders,
configs,
files,
functions,
memories,
models,
knowledge,
prompts,
evaluations,
tools,
users,
utils,
)
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.config import (
ADMIN_EMAIL,
CORS_ALLOW_ORIGIN,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
DEFAULT_USER_ROLE,
ENABLE_COMMUNITY_SHARING,
ENABLE_LOGIN_FORM,
ENABLE_MESSAGE_RATING,
ENABLE_SIGNUP,
ENABLE_EVALUATION_ARENA_MODELS,
EVALUATION_ARENA_MODELS,
DEFAULT_ARENA_MODEL,
JWT_EXPIRES_IN,
ENABLE_OAUTH_ROLE_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES,
SHOW_ADMIN_DETAILS,
USER_PERMISSIONS,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_BANNERS,
AppConfig,
)
from open_webui.env import (
ENV,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from open_webui.utils.misc import (
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
)
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from open_webui.utils.tools import get_tools
app = FastAPI(
    docs_url="/docs" if ENV == "dev" else None,
    openapi_url="/openapi.json" if ENV == "dev" else None,
    redoc_url=None,
)
log = logging.getLogger(__name__)
app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS
app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
app.state.MODELS = {}
app.state.TOOLS = {}
app.state.FUNCTIONS = {}
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(functions.router, prefix="/functions", tags=["functions"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
app.include_router(folders.router, prefix="/folders", tags=["folders"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])
@app.get("/")
async def get_status():
return {
"status": True,
"auth": WEBUI_AUTH,
"default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
}
async def get_all_models():
models = []
pipe_models = await get_pipe_models()
models = models + pipe_models
if app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
arena_models = []
if len(app.state.config.EVALUATION_ARENA_MODELS) > 0:
arena_models = [
{
"id": model["id"],
"name": model["name"],
"info": {
"meta": model["meta"],
},
"object": "model",
"created": int(time.time()),
"owned_by": "arena",
"arena": True,
}
for model in app.state.config.EVALUATION_ARENA_MODELS
]
else:
# Add default arena model
arena_models = [
{
"id": DEFAULT_ARENA_MODEL["id"],
"name": DEFAULT_ARENA_MODEL["name"],
"info": {
"meta": DEFAULT_ARENA_MODEL["meta"],
},
"object": "model",
"created": int(time.time()),
"owned_by": "arena",
"arena": True,
}
]
models = models + arena_models
return models
def get_function_module(pipe_id: str):
# Check if function is already loaded
if pipe_id not in app.state.FUNCTIONS:
function_module, _, _ = load_function_module_by_id(pipe_id)
app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
return function_module
async def get_pipe_models():
pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = []
for pipe in pipes:
function_module = get_function_module(pipe.id)
# Check if function is a manifold
if hasattr(function_module, "pipes"):
sub_pipes = []
# Check if pipes is a function or a list
try:
if callable(function_module.pipes):
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:
log.exception(e)
sub_pipes = []
print(sub_pipes)
for p in sub_pipes:
sub_pipe_id = f'{pipe.id}.{p["id"]}'
sub_pipe_name = p["name"]
if hasattr(function_module, "name"):
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
pipe_flag = {"type": pipe.type}
pipe_models.append(
{
"id": sub_pipe_id,
"name": sub_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
else:
pipe_flag = {"type": "pipe"}
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
return pipe_models
async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
else:
return pipe(**params)
async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
return res
if isinstance(res, Generator):
return "".join(map(str, res))
if isinstance(res, AsyncGenerator):
return "".join([str(stream) async for stream in res])
def process_line(form_data: dict, line):
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except Exception:
pass
if line.startswith("data:"):
return f"{line}\n\n"
else:
line = openai_chat_chunk_message_template(form_data["model"], line)
return f"data: {json.dumps(line)}\n\n"
def get_pipe_id(form_data: dict) -> str:
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, _ = pipe_id.split(".", 1)
print(pipe_id)
return pipe_id
def get_function_params(function_module, form_data, user, extra_params=None):
if extra_params is None:
extra_params = {}
pipe_id = get_pipe_id(form_data)
# Get the signature of the function
sig = inspect.signature(function_module.pipe)
params = {"body": form_data} | {
k: v for k, v in extra_params.items() if k in sig.parameters
}
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
return params
async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", {})
files = metadata.get("files", [])
tool_ids = metadata.get("tool_ids", [])
# Check if tool_ids is None
if tool_ids is None:
tool_ids = []
__event_emitter__ = None
__event_call__ = None
__task__ = None
__task_body__ = None
if metadata:
if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
__event_emitter__ = get_event_emitter(metadata)
__event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None)
__task_body__ = metadata.get("task_body", None)
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
}
extra_params["__tools__"] = get_tools(
app,
tool_ids,
user,
{
**extra_params,
"__model__": app.state.MODELS[form_data["model"]],
"__messages__": form_data["messages"],
"__files__": files,
},
)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, user)
pipe_id = get_pipe_id(form_data)
function_module = get_function_module(pipe_id)
pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params)
if form_data.get("stream", False):
async def stream_content():
try:
res = await execute_pipe(pipe, params)
# Directly return if the response is a StreamingResponse
if isinstance(res, StreamingResponse):
async for data in res.body_iterator:
yield data
return
if isinstance(res, dict):
yield f"data: {json.dumps(res)}\n\n"
return
except Exception as e:
print(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, str):
message = openai_chat_chunk_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
yield process_line(form_data, line)
if isinstance(res, AsyncGenerator):
async for line in res:
yield process_line(form_data, line)
if isinstance(res, str) or isinstance(res, Generator):
finish_message = openai_chat_chunk_message_template(
form_data["model"], ""
)
finish_message["choices"][0]["finish_reason"] = "stop"
yield f"data: {json.dumps(finish_message)}\n\n"
yield "data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
res = await execute_pipe(pipe, params)
except Exception as e:
print(f"Error: {e}")
return {"error": {"detail": str(e)}}
if isinstance(res, StreamingResponse) or isinstance(res, dict):
return res
if isinstance(res, BaseModel):
return res.model_dump()
message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)

View File

@ -1,157 +0,0 @@
import json
import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Documents DB Schema
####################
class Document(Base):
__tablename__ = "document"
collection_name = Column(String, primary_key=True)
name = Column(String, unique=True)
title = Column(Text)
filename = Column(Text)
content = Column(Text, nullable=True)
user_id = Column(String)
timestamp = Column(BigInteger)
class DocumentModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
collection_name: str
name: str
title: str
filename: str
content: Optional[str] = None
user_id: str
timestamp: int # timestamp in epoch
####################
# Forms
####################
class DocumentResponse(BaseModel):
collection_name: str
name: str
title: str
filename: str
content: Optional[dict] = None
user_id: str
timestamp: int # timestamp in epoch
class DocumentUpdateForm(BaseModel):
name: str
title: str
class DocumentForm(DocumentUpdateForm):
collection_name: str
filename: str
content: Optional[str] = None
class DocumentsTable:
def insert_new_doc(
self, user_id: str, form_data: DocumentForm
) -> Optional[DocumentModel]:
with get_db() as db:
document = DocumentModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"timestamp": int(time.time()),
}
)
try:
result = Document(**document.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return DocumentModel.model_validate(result)
else:
return None
except Exception:
return None
def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
try:
with get_db() as db:
document = db.query(Document).filter_by(name=name).first()
return DocumentModel.model_validate(document) if document else None
except Exception:
return None
def get_docs(self) -> list[DocumentModel]:
with get_db() as db:
return [
DocumentModel.model_validate(doc) for doc in db.query(Document).all()
]
def update_doc_by_name(
self, name: str, form_data: DocumentUpdateForm
) -> Optional[DocumentModel]:
try:
with get_db() as db:
db.query(Document).filter_by(name=name).update(
{
"title": form_data.title,
"name": form_data.name,
"timestamp": int(time.time()),
}
)
db.commit()
return self.get_doc_by_name(form_data.name)
except Exception as e:
log.exception(e)
return None
def update_doc_content_by_name(
self, name: str, updated: dict
) -> Optional[DocumentModel]:
try:
doc = self.get_doc_by_name(name)
doc_content = json.loads(doc.content if doc.content else "{}")
doc_content = {**doc_content, **updated}
with get_db() as db:
db.query(Document).filter_by(name=name).update(
{
"content": json.dumps(doc_content),
"timestamp": int(time.time()),
}
)
db.commit()
return self.get_doc_by_name(name)
except Exception as e:
log.exception(e)
return None
def delete_doc_by_name(self, name: str) -> bool:
try:
with get_db() as db:
db.query(Document).filter_by(name=name).delete()
db.commit()
return True
except Exception:
return False
Documents = DocumentsTable()

View File

@ -1,155 +0,0 @@
import json
from typing import Optional
from open_webui.apps.webui.models.documents import (
DocumentForm,
DocumentResponse,
Documents,
DocumentUpdateForm,
)
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from open_webui.utils.utils import get_admin_user, get_verified_user
router = APIRouter()
############################
# GetDocuments
############################
@router.get("/", response_model=list[DocumentResponse])
async def get_documents(user=Depends(get_verified_user)):
docs = [
DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
for doc in Documents.get_docs()
]
return docs
############################
# CreateNewDoc
############################
@router.post("/create", response_model=Optional[DocumentResponse])
async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
doc = Documents.get_doc_by_name(form_data.name)
if doc is None:
doc = Documents.insert_new_doc(user.id, form_data)
if doc:
return DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_EXISTS,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NAME_TAG_TAKEN,
)
############################
# GetDocByName
############################
@router.get("/doc", response_model=Optional[DocumentResponse])
async def get_doc_by_name(name: str, user=Depends(get_verified_user)):
doc = Documents.get_doc_by_name(name)
if doc:
return DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# TagDocByName
############################
class TagItem(BaseModel):
name: str
class TagDocumentForm(BaseModel):
name: str
tags: list[dict]
@router.post("/doc/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_user)):
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
if doc:
return DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateDocByName
############################
@router.post("/doc/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str,
form_data: DocumentUpdateForm,
user=Depends(get_admin_user),
):
doc = Documents.update_doc_by_name(name, form_data)
if doc:
return DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NAME_TAG_TAKEN,
)
############################
# DeleteDocByName
############################
@router.delete("/doc/delete", response_model=bool)
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
result = Documents.delete_doc_by_name(name)
return result

View File

@ -1,104 +0,0 @@
from typing import Optional
from open_webui.apps.webui.models.models import (
ModelForm,
ModelModel,
ModelResponse,
Models,
)
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.utils import get_admin_user, get_verified_user
router = APIRouter()
###########################
# getModels
###########################
@router.get("/", response_model=list[ModelResponse])
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
if id:
model = Models.get_model_by_id(id)
if model:
return [model]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
return Models.get_all_models()
############################
# AddNewModel
############################
@router.post("/add", response_model=Optional[ModelModel])
async def add_new_model(
request: Request,
form_data: ModelForm,
user=Depends(get_admin_user),
):
if form_data.id in request.app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
)
else:
model = Models.insert_new_model(form_data, user.id)
if model:
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
############################
# UpdateModelById
############################
@router.post("/update", response_model=Optional[ModelModel])
async def update_model_by_id(
request: Request,
id: str,
form_data: ModelForm,
user=Depends(get_admin_user),
):
model = Models.get_model_by_id(id)
if model:
model = Models.update_model_by_id(id, form_data)
return model
else:
if form_data.id in request.app.state.MODELS:
model = Models.insert_new_model(form_data, user.id)
if model:
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
############################
# DeleteModelById
############################
@router.delete("/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
result = Models.delete_model_by_id(id)
return result

View File

@ -1,90 +0,0 @@
from typing import Optional
from open_webui.apps.webui.models.prompts import PromptForm, PromptModel, Prompts
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, status
from open_webui.utils.utils import get_admin_user, get_verified_user
router = APIRouter()
############################
# GetPrompts
############################
@router.get("/", response_model=list[PromptModel])
async def get_prompts(user=Depends(get_verified_user)):
return Prompts.get_prompts()
############################
# CreateNewPrompt
############################
@router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
prompt = Prompts.get_prompt_by_command(form_data.command)
if prompt is None:
prompt = Prompts.insert_new_prompt(user.id, form_data)
if prompt:
return prompt
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(),
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.COMMAND_TAKEN,
)
############################
# GetPromptByCommand
############################
@router.get("/command/{command}", response_model=Optional[PromptModel])
async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}")
if prompt:
return prompt
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdatePromptByCommand
############################
@router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command(
command: str,
form_data: PromptForm,
user=Depends(get_admin_user),
):
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt:
return prompt
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
# DeletePromptByCommand
############################
@router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
result = Prompts.delete_prompt_by_command(f"/{command}")
return result

View File

@ -10,7 +10,7 @@ from urllib.parse import urlparse
import chromadb
import requests
import yaml
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from open_webui.env import (
OPEN_WEBUI_DIR,
DATA_DIR,
@ -20,6 +20,8 @@ from open_webui.env import (
WEBUI_FAVICON_URL,
WEBUI_NAME,
log,
DATABASE_URL,
OFFLINE_MODE
)
from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, Integer, func
@ -264,6 +266,13 @@ class AppConfig:
# WEBUI_AUTH (Required for security)
####################################
ENABLE_API_KEY = PersistentConfig(
"ENABLE_API_KEY",
"auth.api_key.enable",
os.environ.get("ENABLE_API_KEY", "True").lower() == "true",
)
JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
@ -433,6 +442,15 @@ OAUTH_ADMIN_ROLES = PersistentConfig(
[role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
)
OAUTH_ALLOWED_DOMAINS = PersistentConfig(
"OAUTH_ALLOWED_DOMAINS",
"oauth.allowed_domains",
[
domain.strip()
for domain in os.environ.get("OAUTH_ALLOWED_DOMAINS", "*").split(",")
],
)
def load_oauth_providers():
OAUTH_PROVIDERS.clear()
@ -587,6 +605,12 @@ OLLAMA_API_BASE_URL = os.environ.get(
)
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
if OLLAMA_BASE_URL:
# Remove trailing slash
OLLAMA_BASE_URL = (
OLLAMA_BASE_URL[:-1] if OLLAMA_BASE_URL.endswith("/") else OLLAMA_BASE_URL
)
K8S_FLAG = os.environ.get("K8S_FLAG", "")
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")
@ -618,6 +642,12 @@ OLLAMA_BASE_URLS = PersistentConfig(
"OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
)
OLLAMA_API_CONFIGS = PersistentConfig(
"OLLAMA_API_CONFIGS",
"ollama.api_configs",
{},
)
####################################
# OPENAI_API
####################################
@ -658,15 +688,20 @@ OPENAI_API_BASE_URLS = PersistentConfig(
"OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
)
OPENAI_API_KEY = ""
OPENAI_API_CONFIGS = PersistentConfig(
"OPENAI_API_CONFIGS",
"openai.api_configs",
{},
)
# Get the actual OpenAI API key based on the base URL
OPENAI_API_KEY = ""
try:
OPENAI_API_KEY = OPENAI_API_KEYS.value[
OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
]
except Exception:
pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
####################################
@ -689,6 +724,7 @@ ENABLE_LOGIN_FORM = PersistentConfig(
os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
)
DEFAULT_LOCALE = PersistentConfig(
"DEFAULT_LOCALE",
"ui.default_locale",
@ -733,18 +769,47 @@ DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
],
)
MODEL_ORDER_LIST = PersistentConfig(
"MODEL_ORDER_LIST",
"ui.model_order_list",
[],
)
DEFAULT_USER_ROLE = PersistentConfig(
"DEFAULT_USER_ROLE",
"ui.default_user_role",
os.getenv("DEFAULT_USER_ROLE", "pending"),
)
USER_PERMISSIONS_CHAT_DELETION = (
os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower()
== "true"
)
USER_PERMISSIONS_CHAT_EDITING = (
os.environ.get("USER_PERMISSIONS_CHAT_EDITING", "True").lower() == "true"
USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS", "False").lower()
== "true"
)
USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS", "False").lower()
== "true"
)
USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true"
)
USER_PERMISSIONS_CHAT_FILE_UPLOAD = (
os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true"
)
USER_PERMISSIONS_CHAT_DELETE = (
os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true"
)
USER_PERMISSIONS_CHAT_EDIT = (
os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true"
)
USER_PERMISSIONS_CHAT_TEMPORARY = (
@ -753,13 +818,20 @@ USER_PERMISSIONS_CHAT_TEMPORARY = (
USER_PERMISSIONS = PersistentConfig(
"USER_PERMISSIONS",
"ui.user_permissions",
"user.permissions",
{
"workspace": {
"models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS,
"knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS,
"prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS,
"tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS,
},
"chat": {
"deletion": USER_PERMISSIONS_CHAT_DELETION,
"editing": USER_PERMISSIONS_CHAT_EDITING,
"file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD,
"delete": USER_PERMISSIONS_CHAT_DELETE,
"edit": USER_PERMISSIONS_CHAT_EDIT,
"temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
}
},
},
)
@ -785,18 +857,6 @@ DEFAULT_ARENA_MODEL = {
},
}
ENABLE_MODEL_FILTER = PersistentConfig(
"ENABLE_MODEL_FILTER",
"model_filter.enable",
os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
)
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = PersistentConfig(
"MODEL_FILTER_LIST",
"model_filter.list",
[model.strip() for model in MODEL_FILTER_LIST.split(";")],
)
WEBHOOK_URL = PersistentConfig(
"WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
)
@ -910,25 +970,155 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
)
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights
<chat_history>
{{MESSAGES:END:2}}
</chat_history>"""
TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"TAGS_GENERATION_PROMPT_TEMPLATE",
"task.tags.prompt_template",
os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
)
ENABLE_SEARCH_QUERY = PersistentConfig(
"ENABLE_SEARCH_QUERY",
"task.search.enable",
os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true",
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE = """### Task:
Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags.
### Guidelines:
- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation
- If content is too short (less than 3 messages) or too diverse, use only ["General"]
- Use the chat's primary language; default to English if multilingual
- Prioritize accuracy over specificity
### Output:
JSON format: { "tags": ["tag1", "tag2", "tag3"] }
### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>"""
ENABLE_TAGS_GENERATION = PersistentConfig(
"ENABLE_TAGS_GENERATION",
"task.tags.enable",
os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true",
)
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
"task.search.prompt_template",
os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""),
ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig(
"ENABLE_SEARCH_QUERY_GENERATION",
"task.query.search.enable",
os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true",
)
ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig(
"ENABLE_RETRIEVAL_QUERY_GENERATION",
"task.query.retrieval.enable",
os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true",
)
QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"QUERY_GENERATION_PROMPT_TEMPLATE",
"task.query.prompt_template",
os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""),
)
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task:
Analyze the chat history to determine the necessity of generating search queries, in the given language. By default, **prioritize generating 1-3 broad and relevant search queries** unless it is absolutely certain that no additional information is required. The aim is to retrieve comprehensive, updated, and valuable information even with minimal uncertainty. If no search is unequivocally needed, return an empty list.
### Guidelines:
- Respond **EXCLUSIVELY** with a JSON object. Any form of extra commentary, explanation, or additional text is strictly prohibited.
- When generating search queries, respond in the format: { "queries": ["query1", "query2"] }, ensuring each query is distinct, concise, and relevant to the topic.
- If and only if it is entirely certain that no useful results can be retrieved by a search, return: { "queries": [] }.
- Err on the side of suggesting search queries if there is **any chance** they might provide useful or updated information.
- Be concise and focused on composing high-quality search queries, avoiding unnecessary elaboration, commentary, or assumptions.
- Today's date is: {{CURRENT_DATE}}.
- Always prioritize providing actionable and broad queries that maximize informational coverage.
### Output:
Strictly return in JSON format:
{
"queries": ["query1", "query2"]
}
### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>
"""
ENABLE_AUTOCOMPLETE_GENERATION = PersistentConfig(
"ENABLE_AUTOCOMPLETE_GENERATION",
"task.autocomplete.enable",
os.environ.get("ENABLE_AUTOCOMPLETE_GENERATION", "True").lower() == "true",
)
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = PersistentConfig(
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
"task.autocomplete.input_max_length",
int(os.environ.get("AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH", "-1")),
)
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE",
"task.autocomplete.prompt_template",
os.environ.get("AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", ""),
)
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = """### Task:
You are an autocompletion system. Continue the text in `<text>` based on the **completion type** in `<type>` and the given language.
### **Instructions**:
1. Analyze `<text>` for context and meaning.
2. Use `<type>` to guide your output:
- **General**: Provide a natural, concise continuation.
- **Search Query**: Complete as if generating a realistic search query.
3. Start as if you are directly continuing `<text>`. Do **not** repeat, paraphrase, or respond as a model. Simply complete the text.
4. Ensure the continuation:
- Flows naturally from `<text>`.
- Avoids repetition, overexplaining, or unrelated ideas.
5. If unsure, return: `{ "text": "" }`.
### **Output Rules**:
- Respond only in JSON format: `{ "text": "<your_completion>" }`.
### **Examples**:
#### Example 1:
Input:
<type>General</type>
<text>The sun was setting over the horizon, painting the sky</text>
Output:
{ "text": "with vibrant shades of orange and pink." }
#### Example 2:
Input:
<type>Search Query</type>
<text>Top-rated restaurants in</text>
Output:
{ "text": "New York City for Italian cuisine." }
---
### Context:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>
<type>{{TYPE}}</type>
<text>{{PROMPT}}</text>
#### Output:
"""
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
@ -937,6 +1127,19 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
)
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text."""
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
Message: ```{{prompt}}```"""
DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
Responses from models: {{responses}}"""
####################################
# Vector Database
####################################
@ -949,6 +1152,8 @@ CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "")
CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get("CHROMA_CLIENT_AUTH_CREDENTIALS", "")
# Comma-separated list of header=value pairs
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
if CHROMA_HTTP_HEADERS:
@ -966,6 +1171,21 @@ MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
# Qdrant
QDRANT_URI = os.environ.get("QDRANT_URI", None)
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
# OpenSearch
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", True)
OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
# Pgvector
PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL)
if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
raise ValueError(
"Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database."
)
####################################
# Information Retrieval (RAG)
@ -1045,11 +1265,11 @@ RAG_EMBEDDING_MODEL = PersistentConfig(
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
not OFFLINE_MODE and os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
)
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
)
RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
@ -1070,11 +1290,11 @@ if RAG_RERANKING_MODEL.value != "":
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
RAG_RERANKING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
not OFFLINE_MODE and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
)
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
)
@ -1102,21 +1322,32 @@ CHUNK_OVERLAP = PersistentConfig(
int(os.environ.get("CHUNK_OVERLAP", "100")),
)
DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
DEFAULT_RAG_TEMPLATE = """### Task:
Respond to the user query using the provided context, incorporating inline citations in the format [source_id] **only when the <source_id> tag is explicitly provided** in the context.
### Guidelines:
- If you don't know the answer, clearly state that.
- If uncertain, ask the user for clarification.
- Respond in the same language as the user's query.
- If the context is unreadable or of poor quality, inform the user and provide the best possible answer.
- If the answer isn't present in the context but you possess the knowledge, explain this to the user and provide the answer using your own understanding.
- **Only include inline citations using [source_id] when a <source_id> tag is explicitly provided in the context.**
- Do not cite if the <source_id> tag is not provided in the context.
- Do not use XML tags in your response.
- Ensure citations are concise and directly related to the information provided.
### Example of Citation:
If the user asks about a specific topic and the information is found in "whitepaper.pdf" with a provided <source_id>, the response should include the citation like so:
* "According to the study, the proposed method increases efficiency by 20% [whitepaper.pdf]."
If no <source_id> is present, the response should omit the citation.
### Output:
Provide a clear and direct response to the user's query, including inline citations in the format [source_id] only when the <source_id> tag is present in the context.
<context>
{{CONTEXT}}
</context>
<rules>
- If you don't know, just say so.
- If you are not sure, ask for clarification.
- Answer in the same language as the user query.
- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
- Answer directly and without using xml tags.
</rules>
<user_query>
{{QUERY}}
</user_query>
@ -1139,6 +1370,19 @@ RAG_OPENAI_API_KEY = PersistentConfig(
os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
)
RAG_OLLAMA_BASE_URL = PersistentConfig(
"RAG_OLLAMA_BASE_URL",
"rag.ollama.url",
os.getenv("RAG_OLLAMA_BASE_URL", OLLAMA_BASE_URL),
)
RAG_OLLAMA_API_KEY = PersistentConfig(
"RAG_OLLAMA_API_KEY",
"rag.ollama.key",
os.getenv("RAG_OLLAMA_API_KEY", ""),
)
ENABLE_RAG_LOCAL_WEB_FETCH = (
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)
@ -1149,6 +1393,12 @@ YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)
YOUTUBE_LOADER_PROXY_URL = PersistentConfig(
"YOUTUBE_LOADER_PROXY_URL",
"rag.youtube_loader_proxy_url",
os.getenv("YOUTUBE_LOADER_PROXY_URL", ""),
)
ENABLE_RAG_WEB_SEARCH = PersistentConfig(
"ENABLE_RAG_WEB_SEARCH",
@ -1198,6 +1448,18 @@ BRAVE_SEARCH_API_KEY = PersistentConfig(
os.getenv("BRAVE_SEARCH_API_KEY", ""),
)
KAGI_SEARCH_API_KEY = PersistentConfig(
"KAGI_SEARCH_API_KEY",
"rag.web.search.kagi_search_api_key",
os.getenv("KAGI_SEARCH_API_KEY", ""),
)
MOJEEK_SEARCH_API_KEY = PersistentConfig(
"MOJEEK_SEARCH_API_KEY",
"rag.web.search.mojeek_search_api_key",
os.getenv("MOJEEK_SEARCH_API_KEY", ""),
)
SERPSTACK_API_KEY = PersistentConfig(
"SERPSTACK_API_KEY",
"rag.web.search.serpstack_api_key",
@ -1228,6 +1490,12 @@ TAVILY_API_KEY = PersistentConfig(
os.getenv("TAVILY_API_KEY", ""),
)
JINA_API_KEY = PersistentConfig(
"JINA_API_KEY",
"rag.web.search.jina_api_key",
os.getenv("JINA_API_KEY", ""),
)
SEARCHAPI_API_KEY = PersistentConfig(
"SEARCHAPI_API_KEY",
"rag.web.search.searchapi_api_key",
@ -1240,6 +1508,21 @@ SEARCHAPI_ENGINE = PersistentConfig(
os.getenv("SEARCHAPI_ENGINE", ""),
)
BING_SEARCH_V7_ENDPOINT = PersistentConfig(
"BING_SEARCH_V7_ENDPOINT",
"rag.web.search.bing_search_v7_endpoint",
os.environ.get(
"BING_SEARCH_V7_ENDPOINT", "https://api.bing.microsoft.com/v7.0/search"
),
)
BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig(
"BING_SEARCH_V7_SUBSCRIPTION_KEY",
"rag.web.search.bing_search_v7_subscription_key",
os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""),
)
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"RAG_WEB_SEARCH_RESULT_COUNT",
"rag.web.search.result_count",
@ -1291,7 +1574,7 @@ AUTOMATIC1111_CFG_SCALE = PersistentConfig(
AUTOMATIC1111_SAMPLER = PersistentConfig(
"AUTOMATIC1111_SAMPLERE",
"AUTOMATIC1111_SAMPLER",
"image_generation.automatic1111.sampler",
(
os.environ.get("AUTOMATIC1111_SAMPLER")
@ -1477,7 +1760,7 @@ WHISPER_MODEL = PersistentConfig(
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
not OFFLINE_MODE and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
@ -1560,3 +1843,74 @@ AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT = PersistentConfig(
"AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3"
),
)
####################################
# LDAP
####################################
ENABLE_LDAP = PersistentConfig(
"ENABLE_LDAP",
"ldap.enable",
os.environ.get("ENABLE_LDAP", "false").lower() == "true",
)
LDAP_SERVER_LABEL = PersistentConfig(
"LDAP_SERVER_LABEL",
"ldap.server.label",
os.environ.get("LDAP_SERVER_LABEL", "LDAP Server"),
)
LDAP_SERVER_HOST = PersistentConfig(
"LDAP_SERVER_HOST",
"ldap.server.host",
os.environ.get("LDAP_SERVER_HOST", "localhost"),
)
LDAP_SERVER_PORT = PersistentConfig(
"LDAP_SERVER_PORT",
"ldap.server.port",
int(os.environ.get("LDAP_SERVER_PORT", "389")),
)
LDAP_ATTRIBUTE_FOR_USERNAME = PersistentConfig(
"LDAP_ATTRIBUTE_FOR_USERNAME",
"ldap.server.attribute_for_username",
os.environ.get("LDAP_ATTRIBUTE_FOR_USERNAME", "uid"),
)
LDAP_APP_DN = PersistentConfig(
"LDAP_APP_DN", "ldap.server.app_dn", os.environ.get("LDAP_APP_DN", "")
)
LDAP_APP_PASSWORD = PersistentConfig(
"LDAP_APP_PASSWORD",
"ldap.server.app_password",
os.environ.get("LDAP_APP_PASSWORD", ""),
)
LDAP_SEARCH_BASE = PersistentConfig(
"LDAP_SEARCH_BASE", "ldap.server.users_dn", os.environ.get("LDAP_SEARCH_BASE", "")
)
LDAP_SEARCH_FILTERS = PersistentConfig(
"LDAP_SEARCH_FILTER",
"ldap.server.search_filter",
os.environ.get("LDAP_SEARCH_FILTER", ""),
)
LDAP_USE_TLS = PersistentConfig(
"LDAP_USE_TLS",
"ldap.server.use_tls",
os.environ.get("LDAP_USE_TLS", "True").lower() == "true",
)
LDAP_CA_CERT_FILE = PersistentConfig(
"LDAP_CA_CERT_FILE",
"ldap.server.ca_cert_file",
os.environ.get("LDAP_CA_CERT_FILE", ""),
)
LDAP_CIPHERS = PersistentConfig(
"LDAP_CIPHERS", "ldap.server.ciphers", os.environ.get("LDAP_CIPHERS", "ALL")
)

View File

@ -62,6 +62,7 @@ class ERROR_MESSAGES(str, Enum):
NOT_FOUND = "We could not find what you're looking for :/"
USER_NOT_FOUND = "We could not find what you're looking for :/"
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
API_KEY_NOT_ALLOWED = "Use of API key is not enabled in the environment."
MALICIOUS = "Unusual activities detected, please try again in a few minutes."
@ -75,6 +76,7 @@ class ERROR_MESSAGES(str, Enum):
OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
API_KEY_CREATION_NOT_ALLOWED = "API key creation is not allowed in the environment."
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
@ -111,5 +113,6 @@ class TASKS(str, Enum):
TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"
AUTOCOMPLETE_GENERATION = "autocomplete_generation"
FUNCTION_CALLING = "function_calling"
MOA_RESPONSE_GENERATION = "moa_response_generation"

View File

@ -195,6 +195,15 @@ CHANGELOG = changelog_json
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
####################################
# ENABLE_FORWARD_USER_INFO_HEADERS
####################################
ENABLE_FORWARD_USER_INFO_HEADERS = (
os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
)
####################################
# WEBUI_BUILD_HASH
####################################
@ -320,6 +329,9 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
)
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
BYPASS_MODEL_ACCESS_CONTROL = (
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
)
####################################
# WEBUI_SECRET_KEY
@ -364,7 +376,7 @@ else:
AIOHTTP_CLIENT_TIMEOUT = 300
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "3"
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
)
if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
@ -375,7 +387,7 @@ else:
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
)
except Exception:
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 3
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 5
####################################
# OFFLINE_MODE

View File

@ -0,0 +1,316 @@
import logging
import sys
import inspect
import json
from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
from fastapi import (
Depends,
FastAPI,
File,
Form,
HTTPException,
Request,
UploadFile,
status,
)
from starlette.responses import Response, StreamingResponse
from open_webui.socket.main import (
get_event_call,
get_event_emitter,
)
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.tools import get_tools
from open_webui.utils.access_control import has_access
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
from open_webui.utils.misc import (
add_or_update_system_message,
get_last_user_message,
prepend_to_first_user_message_content,
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
)
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module_by_id(request: Request, pipe_id: str):
# Check if function is already loaded
if pipe_id not in request.app.state.FUNCTIONS:
function_module, _, _ = load_function_module_by_id(pipe_id)
request.app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = request.app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
return function_module
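# Build OpenAI-style model entries for every active "pipe" function; a manifold
# (a function exposing multiple pipes) expands into one entry per sub-pipe.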
async def get_function_models(request):
pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = []
for pipe in pipes:
function_module = get_function_module_by_id(request, pipe.id)
# Check if function is a manifold
if hasattr(function_module, "pipes"):
sub_pipes = []
# Check if pipes is a function or a list
try:
if callable(function_module.pipes):
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:
log.exception(e)
sub_pipes = []
log.debug(
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
)
for p in sub_pipes:
sub_pipe_id = f'{pipe.id}.{p["id"]}'
sub_pipe_name = p["name"]
if hasattr(function_module, "name"):
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
pipe_flag = {"type": pipe.type}
pipe_models.append(
{
"id": sub_pipe_id,
"name": sub_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
else:
pipe_flag = {"type": "pipe"}
log.debug(
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
)
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
return pipe_models
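# Run a pipe function as an OpenAI-compatible chat completion, emitting SSE
# chunks when the caller requests streaming.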
async def generate_function_chat_completion(
request, form_data, user, models: dict = {}
):
async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
else:
return pipe(**params)
async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
return res
if isinstance(res, Generator):
return "".join(map(str, res))
if isinstance(res, AsyncGenerator):
return "".join([str(stream) async for stream in res])
def process_line(form_data: dict, line):
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except Exception:
pass
if line.startswith("data:"):
return f"{line}\n\n"
else:
line = openai_chat_chunk_message_template(form_data["model"], line)
return f"data: {json.dumps(line)}\n\n"
def get_pipe_id(form_data: dict) -> str:
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, _ = pipe_id.split(".", 1)
return pipe_id
def get_function_params(function_module, form_data, user, extra_params=None):
if extra_params is None:
extra_params = {}
pipe_id = get_pipe_id(form_data)
# Get the signature of the function
sig = inspect.signature(function_module.pipe)
params = {"body": form_data} | {
k: v for k, v in extra_params.items() if k in sig.parameters
}
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
return params
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", {})
files = metadata.get("files", [])
tool_ids = metadata.get("tool_ids", [])
# Check if tool_ids is None
if tool_ids is None:
tool_ids = []
__event_emitter__ = None
__event_call__ = None
__task__ = None
__task_body__ = None
if metadata:
if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
__event_emitter__ = get_event_emitter(metadata)
__event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None)
__task_body__ = metadata.get("task_body", None)
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__metadata__": metadata,
"__request__": request,
}
extra_params["__tools__"] = get_tools(
request,
tool_ids,
user,
{
**extra_params,
"__model__": models.get(form_data["model"], None),
"__messages__": form_data["messages"],
"__files__": files,
},
)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, user)
pipe_id = get_pipe_id(form_data)
function_module = get_function_module_by_id(request, pipe_id)
pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params)
if form_data.get("stream", False):
async def stream_content():
try:
res = await execute_pipe(pipe, params)
# Directly return if the response is a StreamingResponse
if isinstance(res, StreamingResponse):
async for data in res.body_iterator:
yield data
return
if isinstance(res, dict):
yield f"data: {json.dumps(res)}\n\n"
return
except Exception as e:
log.error(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, str):
message = openai_chat_chunk_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
yield process_line(form_data, line)
if isinstance(res, AsyncGenerator):
async for line in res:
yield process_line(form_data, line)
if isinstance(res, str) or isinstance(res, Generator):
finish_message = openai_chat_chunk_message_template(
form_data["model"], ""
)
finish_message["choices"][0]["finish_reason"] = "stop"
yield f"data: {json.dumps(finish_message)}\n\n"
yield "data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
res = await execute_pipe(pipe, params)
except Exception as e:
log.error(f"Error: {e}")
return {"error": {"detail": str(e)}}
if isinstance(res, StreamingResponse) or isinstance(res, dict):
return res
if isinstance(res, BaseModel):
return res.model_dump()
message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
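For reference, a hedged sketch of a single-pipe function module that the completion path above could load and execute. Everything below (class and field names, the streaming generator) is an illustrative assumption; the code above only requires a `pipe` attribute plus optional `Valves`/`UserValves` models.

from pydantic import BaseModel


class Pipe:
    class Valves(BaseModel):
        prefix: str = ""

    class UserValves(BaseModel):
        uppercase: bool = False

    def __init__(self):
        self.valves = self.Valves()

    def pipe(self, body: dict, __user__: dict = None):
        # Returning a generator makes stream_content() emit one SSE chunk per item
        # and append the final "stop" chunk, since Generator is handled explicitly.
        text = body.get("messages", [{}])[-1].get("content", "")
        if __user__ and __user__.get("valves") and __user__["valves"].uppercase:
            text = text.upper()
        for word in text.split():
            yield f"{self.valves.prefix}{word} "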

View File

@ -3,7 +3,7 @@ import logging
from contextlib import contextmanager
from typing import Any, Optional
from open_webui.apps.webui.internal.wrappers import register_connection
from open_webui.internal.wrappers import register_connection
from open_webui.env import (
OPEN_WEBUI_DIR,
DATABASE_URL,

File diff suppressed because it is too large

View File

@ -1,7 +1,7 @@
from logging.config import fileConfig
from alembic import context
from open_webui.apps.webui.models.auths import Auth
from open_webui.models.auths import Auth
from open_webui.env import DATABASE_URL
from sqlalchemy import engine_from_config, pool

View File

@ -9,7 +9,7 @@ from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.apps.webui.internal.db
import open_webui.internal.db
${imports if imports else ""}
# revision identifiers, used by Alembic.

View File

@ -11,8 +11,8 @@ from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
import open_webui.apps.webui.internal.db
from open_webui.apps.webui.internal.db import JSONField
import open_webui.internal.db
from open_webui.internal.db import JSONField
from open_webui.migrations.util import get_existing_tables
# revision identifiers, used by Alembic.

View File

@ -0,0 +1,85 @@
"""Add group table
Revision ID: 922e7a387820
Revises: 4ace53fd72c8
Create Date: 2024-11-14 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "922e7a387820"
down_revision = "4ace53fd72c8"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"group",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("permissions", sa.JSON(), nullable=True),
sa.Column("user_ids", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
)
# Add 'access_control' column to 'model' table
op.add_column(
"model",
sa.Column("access_control", sa.JSON(), nullable=True),
)
# Add 'is_active' column to 'model' table
op.add_column(
"model",
sa.Column(
"is_active",
sa.Boolean(),
nullable=False,
server_default=sa.sql.expression.true(),
),
)
# Add 'access_control' column to 'knowledge' table
op.add_column(
"knowledge",
sa.Column("access_control", sa.JSON(), nullable=True),
)
# Add 'access_control' column to 'prompt' table
op.add_column(
"prompt",
sa.Column("access_control", sa.JSON(), nullable=True),
)
# Add 'access_control' column to 'tools' table
op.add_column(
"tool",
sa.Column("access_control", sa.JSON(), nullable=True),
)
def downgrade():
op.drop_table("group")
# Drop 'access_control' column from 'model' table
op.drop_column("model", "access_control")
# Drop 'is_active' column from 'model' table
op.drop_column("model", "is_active")
# Drop 'access_control' column from 'knowledge' table
op.drop_column("knowledge", "access_control")
# Drop 'access_control' column from 'prompt' table
op.drop_column("prompt", "access_control")
# Drop 'access_control' column from 'tools' table
op.drop_column("tool", "access_control")

View File

@ -2,12 +2,12 @@ import logging
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.internal.db import Base, get_db
from open_webui.models.users import UserModel, Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel
from sqlalchemy import Boolean, Column, String, Text
from open_webui.utils.utils import verify_password
from open_webui.utils.auth import verify_password
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -64,6 +64,11 @@ class SigninForm(BaseModel):
password: str
class LdapForm(BaseModel):
user: str
password: str
class ProfileImageUrlForm(BaseModel):
profile_image_url: str

View File

@ -3,8 +3,8 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from pydantic import BaseModel, ConfigDict
@ -203,15 +203,22 @@ class ChatTable:
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
print("update_shared_chat_by_id")
chat = db.get(Chat, chat_id)
print(chat)
chat.title = chat.title
chat.chat = chat.chat
db.commit()
db.refresh(chat)
shared_chat = (
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first()
)
return self.get_chat_by_id(chat.share_id)
if shared_chat is None:
return self.insert_shared_chat_by_chat_id(chat_id)
shared_chat.title = chat.title
shared_chat.chat = chat.chat
shared_chat.updated_at = int(time.time())
db.commit()
db.refresh(shared_chat)
return ChatModel.model_validate(shared_chat)
except Exception:
return None

View File

@ -3,8 +3,8 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.chats import Chats
from open_webui.internal.db import Base, get_db
from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict

View File

@ -2,7 +2,7 @@ import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON

View File

@ -3,8 +3,8 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.chats import Chats
from open_webui.internal.db import Base, get_db
from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict

View File

@ -2,8 +2,8 @@ import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.apps.webui.models.users import Users
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text

View File

@ -0,0 +1,186 @@
import json
import logging
import time
from typing import Optional
import uuid
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON, func
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# UserGroup DB Schema
####################
class Group(Base):
__tablename__ = "group"
id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text)
name = Column(Text)
description = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
permissions = Column(JSON, nullable=True)
user_ids = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class GroupModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
name: str
description: str
data: Optional[dict] = None
meta: Optional[dict] = None
permissions: Optional[dict] = None
user_ids: list[str] = []
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
####################
# Forms
####################
class GroupResponse(BaseModel):
id: str
user_id: str
name: str
description: str
permissions: Optional[dict] = None
data: Optional[dict] = None
meta: Optional[dict] = None
user_ids: list[str] = []
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class GroupForm(BaseModel):
name: str
description: str
class GroupUpdateForm(GroupForm):
permissions: Optional[dict] = None
user_ids: Optional[list[str]] = None
admin_ids: Optional[list[str]] = None
class GroupTable:
def insert_new_group(
self, user_id: str, form_data: GroupForm
) -> Optional[GroupModel]:
with get_db() as db:
group = GroupModel(
**{
**form_data.model_dump(),
"id": str(uuid.uuid4()),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = Group(**group.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return GroupModel.model_validate(result)
else:
return None
except Exception:
return None
def get_groups(self) -> list[GroupModel]:
with get_db() as db:
return [
GroupModel.model_validate(group)
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
]
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
with get_db() as db:
return [
GroupModel.model_validate(group)
for group in db.query(Group)
.filter(
func.json_array_length(Group.user_ids) > 0
) # Ensure array exists
.filter(
Group.user_ids.cast(String).like(f'%"{user_id}"%')
) # String-based check
.order_by(Group.updated_at.desc())
.all()
]
def get_group_by_id(self, id: str) -> Optional[GroupModel]:
try:
with get_db() as db:
group = db.query(Group).filter_by(id=id).first()
return GroupModel.model_validate(group) if group else None
except Exception:
return None
def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
) -> Optional[GroupModel]:
try:
with get_db() as db:
db.query(Group).filter_by(id=id).update(
{
**form_data.model_dump(exclude_none=True),
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_group_by_id(id=id)
except Exception as e:
log.exception(e)
return None
def delete_group_by_id(self, id: str) -> bool:
try:
with get_db() as db:
db.query(Group).filter_by(id=id).delete()
db.commit()
return True
except Exception:
return False
def delete_all_groups(self) -> bool:
with get_db() as db:
try:
db.query(Group).delete()
db.commit()
return True
except Exception:
return False
Groups = GroupTable()
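A hedged usage sketch of the table defined above; the import path, the ids and the permissions payload are placeholders inferred from the forms shown here, not confirmed values.

from open_webui.models.groups import Groups, GroupForm, GroupUpdateForm

group = Groups.insert_new_group(
    user_id="owner-user-id",
    form_data=GroupForm(name="Research", description="Research team"),
)
if group:
    Groups.update_group_by_id(
        group.id,
        GroupUpdateForm(
            name=group.name,
            description=group.description,
            user_ids=["user_id1", "user_id2"],
            permissions={"workspace": {"knowledge": True}},  # assumed permission shape
        ),
    )
    # Membership lookup relies on the JSON user_ids column filtered above.
    print(Groups.get_groups_by_member_id("user_id1"))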

View File

@ -4,15 +4,17 @@ import time
from typing import Optional
import uuid
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.files import FileMetadataResponse
from open_webui.models.files import FileMetadataResponse
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -34,6 +36,23 @@ class Knowledge(Base):
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
# - `{}`: Private access, restricted exclusively to the owner.
# - Custom permissions: Specific access control for reading and writing;
# Can specify group or user-level restrictions:
# {
# "read": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# },
# "write": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# }
# }
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
@ -50,6 +69,8 @@ class KnowledgeModel(BaseModel):
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
@ -59,15 +80,15 @@ class KnowledgeModel(BaseModel):
####################
class KnowledgeResponse(BaseModel):
id: str
name: str
description: str
data: Optional[dict] = None
meta: Optional[dict] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class KnowledgeUserModel(KnowledgeModel):
user: Optional[UserResponse] = None
class KnowledgeResponse(KnowledgeModel):
files: Optional[list[FileMetadataResponse | dict]] = None
class KnowledgeUserResponse(KnowledgeUserModel):
files: Optional[list[FileMetadataResponse | dict]] = None
@ -75,12 +96,7 @@ class KnowledgeForm(BaseModel):
name: str
description: str
data: Optional[dict] = None
class KnowledgeUpdateForm(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
data: Optional[dict] = None
access_control: Optional[dict] = None
class KnowledgeTable:
@ -110,14 +126,33 @@ class KnowledgeTable:
except Exception:
return None
def get_knowledge_items(self) -> list[KnowledgeModel]:
def get_knowledge_bases(self) -> list[KnowledgeUserModel]:
with get_db() as db:
return [
KnowledgeModel.model_validate(knowledge)
for knowledge in db.query(Knowledge)
.order_by(Knowledge.updated_at.desc())
.all()
]
knowledge_bases = []
for knowledge in (
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
):
user = Users.get_user_by_id(knowledge.user_id)
knowledge_bases.append(
KnowledgeUserModel.model_validate(
{
**KnowledgeModel.model_validate(knowledge).model_dump(),
"user": user.model_dump() if user else None,
}
)
)
return knowledge_bases
def get_knowledge_bases_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[KnowledgeUserModel]:
knowledge_bases = self.get_knowledge_bases()
return [
knowledge_base
for knowledge_base in knowledge_bases
if knowledge_base.user_id == user_id
or has_access(user_id, permission, knowledge_base.access_control)
]
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
try:
@ -128,14 +163,32 @@ class KnowledgeTable:
return None
def update_knowledge_by_id(
self, id: str, form_data: KnowledgeUpdateForm, overwrite: bool = False
self, id: str, form_data: KnowledgeForm, overwrite: bool = False
) -> Optional[KnowledgeModel]:
try:
with get_db() as db:
knowledge = self.get_knowledge_by_id(id=id)
db.query(Knowledge).filter_by(id=id).update(
{
**form_data.model_dump(exclude_none=True),
**form_data.model_dump(),
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_knowledge_by_id(id=id)
except Exception as e:
log.exception(e)
return None
def update_knowledge_data_by_id(
self, id: str, data: dict
) -> Optional[KnowledgeModel]:
try:
with get_db() as db:
knowledge = self.get_knowledge_by_id(id=id)
db.query(Knowledge).filter_by(id=id).update(
{
"data": data,
"updated_at": int(time.time()),
}
)

View File

@ -2,7 +2,7 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text

View File

@ -2,10 +2,21 @@ import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text
from sqlalchemy import or_, and_, func
from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -67,6 +78,25 @@ class Model(Base):
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
# - `{}`: Private access, restricted exclusively to the owner.
# - Custom permissions: Specific access control for reading and writing;
# Can specify group or user-level restrictions:
# {
# "read": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# },
# "write": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# }
# }
is_active = Column(Boolean, default=True)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
@ -80,6 +110,9 @@ class ModelModel(BaseModel):
params: ModelParams
meta: ModelMeta
access_control: Optional[dict] = None
is_active: bool
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
@ -91,12 +124,12 @@ class ModelModel(BaseModel):
####################
class ModelResponse(BaseModel):
id: str
name: str
meta: ModelMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ModelUserResponse(ModelModel):
user: Optional[UserResponse] = None
class ModelResponse(ModelModel):
pass
class ModelForm(BaseModel):
@ -105,6 +138,8 @@ class ModelForm(BaseModel):
name: str
meta: ModelMeta
params: ModelParams
access_control: Optional[dict] = None
is_active: bool = True
class ModelsTable:
@ -138,6 +173,39 @@ class ModelsTable:
with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_models(self) -> list[ModelUserResponse]:
with get_db() as db:
models = []
for model in db.query(Model).filter(Model.base_model_id != None).all():
user = Users.get_user_by_id(model.user_id)
models.append(
ModelUserResponse.model_validate(
{
**ModelModel.model_validate(model).model_dump(),
"user": user.model_dump() if user else None,
}
)
)
return models
def get_base_models(self) -> list[ModelModel]:
with get_db() as db:
return [
ModelModel.model_validate(model)
for model in db.query(Model).filter(Model.base_model_id == None).all()
]
def get_models_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ModelUserResponse]:
models = self.get_models()
return [
model
for model in models
if model.user_id == user_id
or has_access(user_id, permission, model.access_control)
]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
with get_db() as db:
@ -146,6 +214,23 @@ class ModelsTable:
except Exception:
return None
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
with get_db() as db:
try:
is_active = db.query(Model).filter_by(id=id).first().is_active
db.query(Model).filter_by(id=id).update(
{
"is_active": not is_active,
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_model_by_id(id)
except Exception:
return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try:
with get_db() as db:
@ -153,7 +238,7 @@ class ModelsTable:
result = (
db.query(Model)
.filter_by(id=id)
.update(model.model_dump(exclude={"id"}, exclude_none=True))
.update(model.model_dump(exclude={"id"}))
)
db.commit()
@ -175,5 +260,15 @@ class ModelsTable:
except Exception:
return False
def delete_all_models(self) -> bool:
try:
with get_db() as db:
db.query(Model).delete()
db.commit()
return True
except Exception:
return False
Models = ModelsTable()
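A hedged illustration of the access_control structure documented in the column comments above and of the ownership-or-access filtering used by get_models_by_user_id; all ids are placeholders.

from open_webui.utils.access_control import has_access

access_control = {
    "read": {"group_ids": ["group_id1"], "user_ids": ["user_id2"]},
    "write": {"group_ids": [], "user_ids": ["owner_user_id"]},
}

# Mirrors the checks above: the owner is always allowed, everyone else goes
# through has_access for the requested permission.
can_read = has_access("user_id2", "read", access_control)
can_write = has_access("user_id2", "write", access_control)
# A value of None means public access for users with the "user" role; {} means owner-only.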

View File

@ -1,9 +1,13 @@
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import BigInteger, Column, String, Text, JSON
from open_webui.utils.access_control import has_access
####################
# Prompts DB Schema
@ -19,6 +23,23 @@ class Prompt(Base):
content = Column(Text)
timestamp = Column(BigInteger)
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
# - `{}`: Private access, restricted exclusively to the owner.
# - Custom permissions: Specific access control for reading and writing;
# Can specify group or user-level restrictions:
# {
# "read": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# },
# "write": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# }
# }
class PromptModel(BaseModel):
command: str
@ -27,6 +48,7 @@ class PromptModel(BaseModel):
content: str
timestamp: int # timestamp in epoch
access_control: Optional[dict] = None
model_config = ConfigDict(from_attributes=True)
@ -35,10 +57,15 @@ class PromptModel(BaseModel):
####################
class PromptUserResponse(PromptModel):
user: Optional[UserResponse] = None
class PromptForm(BaseModel):
command: str
title: str
content: str
access_control: Optional[dict] = None
class PromptsTable:
@ -48,16 +75,14 @@ class PromptsTable:
prompt = PromptModel(
**{
"user_id": user_id,
"command": form_data.command,
"title": form_data.title,
"content": form_data.content,
**form_data.model_dump(),
"timestamp": int(time.time()),
}
)
try:
with get_db() as db:
result = Prompt(**prompt.dict())
result = Prompt(**prompt.model_dump())
db.add(result)
db.commit()
db.refresh(result)
@ -76,11 +101,34 @@ class PromptsTable:
except Exception:
return None
def get_prompts(self) -> list[PromptModel]:
def get_prompts(self) -> list[PromptUserResponse]:
with get_db() as db:
return [
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
]
prompts = []
for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all():
user = Users.get_user_by_id(prompt.user_id)
prompts.append(
PromptUserResponse.model_validate(
{
**PromptModel.model_validate(prompt).model_dump(),
"user": user.model_dump() if user else None,
}
)
)
return prompts
def get_prompts_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[PromptUserResponse]:
prompts = self.get_prompts()
return [
prompt
for prompt in prompts
if prompt.user_id == user_id
or has_access(user_id, permission, prompt.access_control)
]
def update_prompt_by_command(
self, command: str, form_data: PromptForm
@ -90,6 +138,7 @@ class PromptsTable:
prompt = db.query(Prompt).filter_by(command=command).first()
prompt.title = form_data.title
prompt.content = form_data.content
prompt.access_control = form_data.access_control
prompt.timestamp = int(time.time())
db.commit()
return PromptModel.model_validate(prompt)

View File

@ -3,7 +3,7 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS

View File

@ -2,11 +2,14 @@ import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.apps.webui.models.users import Users
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users, UserResponse
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import BigInteger, Column, String, Text, JSON
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -26,6 +29,24 @@ class Tool(Base):
specs = Column(JSONField)
meta = Column(JSONField)
valves = Column(JSONField)
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
# - `{}`: Private access, restricted exclusively to the owner.
# - Custom permissions: Specific access control for reading and writing;
# Can specify group or user-level restrictions:
# {
# "read": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# },
# "write": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# }
# }
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
@ -42,6 +63,8 @@ class ToolModel(BaseModel):
content: str
specs: list[dict]
meta: ToolMeta
access_control: Optional[dict] = None
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
@ -53,20 +76,30 @@ class ToolModel(BaseModel):
####################
class ToolUserModel(ToolModel):
user: Optional[UserResponse] = None
class ToolResponse(BaseModel):
id: str
user_id: str
name: str
meta: ToolMeta
access_control: Optional[dict] = None
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ToolUserResponse(ToolResponse):
user: Optional[UserResponse] = None
class ToolForm(BaseModel):
id: str
name: str
content: str
meta: ToolMeta
access_control: Optional[dict] = None
class ToolValves(BaseModel):
@ -109,9 +142,32 @@ class ToolsTable:
except Exception:
return None
def get_tools(self) -> list[ToolModel]:
def get_tools(self) -> list[ToolUserModel]:
with get_db() as db:
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
tools = []
for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all():
user = Users.get_user_by_id(tool.user_id)
tools.append(
ToolUserModel.model_validate(
{
**ToolModel.model_validate(tool).model_dump(),
"user": user.model_dump() if user else None,
}
)
)
return tools
def get_tools_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ToolUserModel]:
tools = self.get_tools()
return [
tool
for tool in tools
if tool.user_id == user_id
or has_access(user_id, permission, tool.access_control)
]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try:

View File

@ -1,8 +1,8 @@
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.apps.webui.models.chats import Chats
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.chats import Chats
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
@ -62,6 +62,14 @@ class UserModel(BaseModel):
####################
class UserResponse(BaseModel):
id: str
name: str
email: str
role: str
profile_image_url: str
class UserRoleUpdateForm(BaseModel):
id: str
role: str

View File

@ -1,6 +1,7 @@
import requests
import logging
import ftfy
import sys
from langchain_community.document_loaders import (
BSHTMLLoader,
@ -18,8 +19,9 @@ from langchain_community.document_loaders import (
YoutubeLoader,
)
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -106,7 +108,7 @@ class TikaLoader:
if "Content-Type" in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"]
log.info("Tika extracted text: %s", text)
log.debug("Tika extracted text: %s", text)
return [Document(page_content=text, metadata=headers)]
else:
@ -159,7 +161,7 @@ class Loader:
elif file_ext in ["htm", "html"]:
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
elif file_ext == "md":
loader = UnstructuredMarkdownLoader(file_path)
loader = TextLoader(file_path, autodetect_encoding=True)
elif file_content_type == "application/epub+zip":
loader = UnstructuredEPubLoader(file_path)
elif (

View File

@ -0,0 +1,117 @@
import logging
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from urllib.parse import parse_qs, urlparse
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
ALLOWED_SCHEMES = {"http", "https"}
ALLOWED_NETLOCS = {
"youtu.be",
"m.youtube.com",
"youtube.com",
"www.youtube.com",
"www.youtube-nocookie.com",
"vid.plus",
}
def _parse_video_id(url: str) -> Optional[str]:
"""Parse a YouTube URL and return the video ID if valid, otherwise None."""
parsed_url = urlparse(url)
if parsed_url.scheme not in ALLOWED_SCHEMES:
return None
if parsed_url.netloc not in ALLOWED_NETLOCS:
return None
path = parsed_url.path
if path.endswith("/watch"):
query = parsed_url.query
parsed_query = parse_qs(query)
if "v" in parsed_query:
ids = parsed_query["v"]
video_id = ids if isinstance(ids, str) else ids[0]
else:
return None
else:
path = parsed_url.path.lstrip("/")
video_id = path.split("/")[-1]
if len(video_id) != 11: # Video IDs are 11 characters long
return None
return video_id
class YoutubeLoader:
"""Load `YouTube` video transcripts."""
def __init__(
self,
video_id: str,
language: Union[str, Sequence[str]] = "en",
proxy_url: Optional[str] = None,
):
"""Initialize with YouTube video ID."""
_video_id = _parse_video_id(video_id)
self.video_id = _video_id if _video_id is not None else video_id
self._metadata = {"source": video_id}
self.language = language
self.proxy_url = proxy_url
if isinstance(language, str):
self.language = [language]
else:
self.language = language
def load(self) -> List[Document]:
"""Load YouTube transcripts into `Document` objects."""
try:
from youtube_transcript_api import (
NoTranscriptFound,
TranscriptsDisabled,
YouTubeTranscriptApi,
)
except ImportError:
raise ImportError(
'Could not import "youtube_transcript_api" Python package. '
"Please install it with `pip install youtube-transcript-api`."
)
if self.proxy_url:
youtube_proxies = {
"http": self.proxy_url,
"https": self.proxy_url,
}
# Don't log complete URL because it might contain secrets
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
else:
youtube_proxies = None
try:
transcript_list = YouTubeTranscriptApi.list_transcripts(
self.video_id, proxies=youtube_proxies
)
except Exception as e:
log.exception("Loading YouTube transcript failed")
return []
try:
transcript = transcript_list.find_transcript(self.language)
except NoTranscriptFound:
transcript = transcript_list.find_transcript(["en"])
transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
transcript = " ".join(
map(
lambda transcript_piece: transcript_piece["text"].strip(" "),
transcript_pieces,
)
)
return [Document(page_content=transcript, metadata=self._metadata)]
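A hedged usage sketch of the loader above; the URL is a placeholder and the youtube-transcript-api package must be installed.

loader = YoutubeLoader(
    "https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # plain 11-character video ids also work
    language=["en", "de"],
)
docs = loader.load()  # empty list if listing transcripts fails
if docs:
    print(docs[0].metadata["source"], len(docs[0].page_content))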

View File

@ -3,6 +3,7 @@ import os
import uuid
from typing import Optional, Union
import asyncio
import requests
from huggingface_hub import snapshot_download
@ -10,17 +11,10 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from open_webui.apps.ollama.main import (
GenerateEmbedForm,
generate_ollama_batch_embeddings,
)
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.env import SRC_LOG_LEVELS
from open_webui.config import DEFAULT_RAG_TEMPLATE
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -76,7 +70,7 @@ def query_doc(
limit=k,
)
log.info(f"query_doc:result {result}")
log.info(f"query_doc:result {result.ids} {result.metadatas}")
return result
except Exception as e:
print(e)
@ -127,7 +121,10 @@ def query_doc_with_hybrid_search(
"metadatas": [[d.metadata for d in result]],
}
log.info(f"query_doc_with_hybrid_search:result {result}")
log.info(
"query_doc_with_hybrid_search:result "
+ f'{result["metadatas"]} {result["distances"]}'
)
return result
except Exception as e:
raise e
@ -178,35 +175,34 @@ def merge_and_sort_query_results(
def query_collection(
collection_names: list[str],
query: str,
queries: list[str],
embedding_function,
k: int,
) -> dict:
results = []
query_embedding = embedding_function(query)
for collection_name in collection_names:
if collection_name:
try:
result = query_doc(
collection_name=collection_name,
k=k,
query_embedding=query_embedding,
)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
for query in queries:
query_embedding = embedding_function(query)
for collection_name in collection_names:
if collection_name:
try:
result = query_doc(
collection_name=collection_name,
k=k,
query_embedding=query_embedding,
)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search(
collection_names: list[str],
query: str,
queries: list[str],
embedding_function,
k: int,
reranking_function,
@ -216,15 +212,16 @@ def query_collection_with_hybrid_search(
error = False
for collection_name in collection_names:
try:
result = query_doc_with_hybrid_search(
collection_name=collection_name,
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
results.append(result)
for query in queries:
result = query_doc_with_hybrid_search(
collection_name=collection_name,
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
results.append(result)
except Exception as e:
log.exception(
"Error when querying the collection with " f"hybrid_search: {e}"
@ -239,50 +236,12 @@ def query_collection_with_hybrid_search(
return merge_and_sort_query_results(results, k=k, reverse=True)
def rag_template(template: str, context: str, query: str):
if template == "":
template = DEFAULT_RAG_TEMPLATE
if "[context]" not in template and "{{CONTEXT}}" not in template:
log.debug(
"WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
)
if "<context>" in context and "</context>" in context:
log.debug(
"WARNING: Potential prompt injection attack: the RAG "
"context contains '<context>' and '</context>'. This might be "
"nothing, or the user might be trying to hack something."
)
query_placeholders = []
if "[query]" in context:
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("[query]", query_placeholder)
query_placeholders.append(query_placeholder)
if "{{QUERY}}" in context:
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("{{QUERY}}", query_placeholder)
query_placeholders.append(query_placeholder)
template = template.replace("[context]", context)
template = template.replace("{{CONTEXT}}", context)
template = template.replace("[query]", query)
template = template.replace("{{QUERY}}", query)
for query_placeholder in query_placeholders:
template = template.replace(query_placeholder, query)
return template
def get_embedding_function(
embedding_engine,
embedding_model,
embedding_function,
openai_key,
openai_url,
url,
key,
embedding_batch_size,
):
if embedding_engine == "":
@ -292,8 +251,8 @@ def get_embedding_function(
engine=embedding_engine,
model=embedding_model,
text=query,
key=openai_key if embedding_engine == "openai" else "",
url=openai_url if embedding_engine == "openai" else "",
url=url,
key=key,
)
def generate_multiple(query, func):
@ -308,17 +267,16 @@ def get_embedding_function(
return lambda query: generate_multiple(query, func)
def get_rag_context(
def get_sources_from_files(
files,
messages,
queries,
embedding_function,
k,
reranking_function,
r,
hybrid_search,
):
log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
query = get_last_user_message(messages)
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
extracted_collections = []
relevant_contexts = []
@ -360,7 +318,7 @@ def get_rag_context(
try:
context = query_collection_with_hybrid_search(
collection_names=collection_names,
query=query,
queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
@ -375,7 +333,7 @@ def get_rag_context(
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
query=query,
queries=queries,
embedding_function=embedding_function,
k=k,
)
@ -389,43 +347,24 @@ def get_rag_context(
del file["data"]
relevant_contexts.append({**context, "file": file})
contexts = []
citations = []
sources = []
for context in relevant_contexts:
try:
if "documents" in context:
file_names = list(
set(
[
metadata["name"]
for metadata in context["metadatas"][0]
if metadata is not None and "name" in metadata
]
)
)
contexts.append(
((", ".join(file_names) + ":\n\n") if file_names else "")
+ "\n\n".join(
[text for text in context["documents"][0] if text is not None]
)
)
if "metadatas" in context:
citation = {
source = {
"source": context["file"],
"document": context["documents"][0],
"metadata": context["metadatas"][0],
}
if "distances" in context and context["distances"]:
citation["distances"] = context["distances"][0]
citations.append(citation)
source["distances"] = context["distances"][0]
sources.append(source)
except Exception as e:
log.exception(e)
print("contexts", contexts)
print("citations", citations)
return contexts, citations
return sources
def get_model_path(model: str, update_model: bool = False):
@ -467,7 +406,7 @@ def get_model_path(model: str, update_model: bool = False):
def generate_openai_batch_embeddings(
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
) -> Optional[list[list[float]]]:
try:
r = requests.post(
@ -489,29 +428,49 @@ def generate_openai_batch_embeddings(
return None
def generate_ollama_batch_embeddings(
model: str, texts: list[str], url: str, key: str = ""
) -> Optional[list[list[float]]]:
try:
r = requests.post(
f"{url}/api/embed",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
},
json={"input": texts, "model": model},
)
r.raise_for_status()
data = r.json()
if "embeddings" in data:
return data["embeddings"]
else:
raise "Something went wrong :/"
except Exception as e:
print(e)
return None
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
url = kwargs.get("url", "")
key = kwargs.get("key", "")
if engine == "ollama":
if isinstance(text, list):
embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": text})
**{"model": model, "texts": text, "url": url, "key": key}
)
else:
embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": [text]})
**{"model": model, "texts": [text], "url": url, "key": key}
)
return (
embeddings["embeddings"][0]
if isinstance(text, str)
else embeddings["embeddings"]
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai":
key = kwargs.get("key", "")
url = kwargs.get("url", "https://api.openai.com/v1")
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
embeddings = generate_openai_batch_embeddings(model, text, url, key)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
embeddings = generate_openai_batch_embeddings(model, [text], url, key)
return embeddings[0] if isinstance(text, str) else embeddings
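Hedged example calls into generate_embeddings above; the model names, URLs and keys are placeholders for whatever embedding engine is configured.

single = generate_embeddings(
    engine="ollama",
    model="nomic-embed-text",
    text="hello world",
    url="http://localhost:11434",
    key="",
)  # one vector (list[float]) because the input is a single string

batch = generate_embeddings(
    engine="openai",
    model="text-embedding-3-small",
    text=["first chunk", "second chunk"],
    url="https://api.openai.com/v1",
    key="sk-placeholder",
)  # list of vectors, one per input string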

View File

@ -0,0 +1,22 @@
from open_webui.config import VECTOR_DB
if VECTOR_DB == "milvus":
from open_webui.retrieval.vector.dbs.milvus import MilvusClient
VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
VECTOR_DB_CLIENT = QdrantClient()
elif VECTOR_DB == "opensearch":
from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
VECTOR_DB_CLIENT = OpenSearchClient()
elif VECTOR_DB == "pgvector":
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
VECTOR_DB_CLIENT = PgvectorClient()
else:
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
VECTOR_DB_CLIENT = ChromaClient()
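A hedged sketch of the shared client interface behind VECTOR_DB_CLIENT; the collection name and vectors are placeholders, and the item dict mirrors the VectorItem shape used by the concrete clients.

from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT

VECTOR_DB_CLIENT.upsert(
    "demo_collection",
    [
        {
            "id": "chunk-1",
            "vector": [0.1, 0.2, 0.3],  # real embeddings come from the embedding function
            "text": "example chunk",
            "metadata": {"name": "example.txt"},
        }
    ],
)
result = VECTOR_DB_CLIENT.search("demo_collection", vectors=[[0.1, 0.2, 0.3]], limit=3)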

View File

@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches
from typing import Optional
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
@ -13,11 +13,24 @@ from open_webui.config import (
CHROMA_HTTP_SSL,
CHROMA_TENANT,
CHROMA_DATABASE,
CHROMA_CLIENT_AUTH_PROVIDER,
CHROMA_CLIENT_AUTH_CREDENTIALS,
)
class ChromaClient:
def __init__(self):
settings_dict = {
"allow_reset": True,
"anonymized_telemetry": False,
}
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
settings_dict["chroma_client_auth_credentials"] = (
CHROMA_CLIENT_AUTH_CREDENTIALS
)
if CHROMA_HTTP_HOST != "":
self.client = chromadb.HttpClient(
host=CHROMA_HTTP_HOST,
@ -26,12 +39,12 @@ class ChromaClient:
ssl=CHROMA_HTTP_SSL,
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
settings=Settings(**settings_dict),
)
else:
self.client = chromadb.PersistentClient(
path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
settings=Settings(**settings_dict),
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
)

View File

@ -4,7 +4,7 @@ import json
from typing import Optional
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
MILVUS_URI,
)

View File

@ -0,0 +1,178 @@
from opensearchpy import OpenSearch
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
OPENSEARCH_URI,
OPENSEARCH_SSL,
OPENSEARCH_CERT_VERIFY,
OPENSEARCH_USERNAME,
OPENSEARCH_PASSWORD,
)
class OpenSearchClient:
def __init__(self):
self.index_prefix = "open_webui"
self.client = OpenSearch(
hosts=[OPENSEARCH_URI],
use_ssl=OPENSEARCH_SSL,
verify_certs=OPENSEARCH_CERT_VERIFY,
http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
)
def _result_to_get_result(self, result) -> GetResult:
ids = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
def _result_to_search_result(self, result) -> SearchResult:
ids = []
distances = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
distances.append(hit["_score"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
def _create_index(self, index_name: str, dimension: int):
body = {
"mappings": {
"properties": {
"id": {"type": "keyword"},
"vector": {
"type": "dense_vector",
"dims": dimension, # Adjust based on your vector dimensions
"index": true,
"similarity": "faiss",
"method": {
"name": "hnsw",
"space_type": "ip", # Use inner product to approximate cosine similarity
"engine": "faiss",
"ef_construction": 128,
"m": 16,
},
},
"text": {"type": "text"},
"metadata": {"type": "object"},
}
}
}
self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body)
def _create_batches(self, items: list[VectorItem], batch_size=100):
for i in range(0, len(items), batch_size):
yield items[i : i + batch_size]
def has_collection(self, index_name: str) -> bool:
# has_collection here means has index.
# We are simply adapting to the norms of the other DBs.
return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
def delete_collection(self, index_name: str):
# delete_collection here means delete index.
# We are simply adapting to the norms of the other DBs.
self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
def search(
self, index_name: str, vectors: list[list[float]], limit: int
) -> Optional[SearchResult]:
query = {
"size": limit,
"_source": ["text", "metadata"],
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
"params": {
"vector": vectors[0]
}, # Assuming single query vector
},
}
},
}
result = self.client.search(
index=f"{self.index_prefix}_{index_name}", body=query
)
return self._result_to_search_result(result)
def get_or_create_index(self, index_name: str, dimension: int):
if not self.has_collection(index_name):
self._create_index(index_name, dimension)
def get(self, index_name: str) -> Optional[GetResult]:
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
result = self.client.search(
index=f"{self.index_prefix}_{index_name}", body=query
)
return self._result_to_get_result(result)
def insert(self, index_name: str, items: list[VectorItem]):
if not self.has_collection(index_name):
self._create_index(index_name, dimension=len(items[0]["vector"]))
for batch in self._create_batches(items):
actions = [
{
"index": {
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
}
for item in batch
]
self.client.bulk(actions)
def upsert(self, index_name: str, items: list[VectorItem]):
if not self.has_collection(index_name):
self._create_index(index_name, dimension=len(items[0]["vector"]))
for batch in self._create_batches(items):
actions = [
{
"index": {
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
}
for item in batch
]
self.client.bulk(actions)
def delete(self, index_name: str, ids: list[str]):
actions = [
{"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
for id in ids
]
self.client.bulk(body=actions)
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
for index in indices:
self.client.indices.delete(index=index)

View File

@ -0,0 +1,354 @@
from typing import Optional, List, Dict, Any
from sqlalchemy import (
cast,
column,
create_engine,
Column,
Integer,
select,
text,
Text,
values,
)
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import PGVECTOR_DB_URL
VECTOR_LENGTH = 1536
Base = declarative_base()
class DocumentChunk(Base):
__tablename__ = "document_chunk"
id = Column(Text, primary_key=True)
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
collection_name = Column(Text, nullable=False)
text = Column(Text, nullable=True)
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient:
def __init__(self) -> None:
# if no pgvector uri, use the existing database connection
if not PGVECTOR_DB_URL:
from open_webui.internal.db import Session
self.session = Session
else:
engine = create_engine(
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
self.session = scoped_session(SessionLocal)
try:
# Ensure the pgvector extension is available
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
# Create the tables if they do not exist
# Base.metadata.create_all requires a bind (engine or connection)
# Get the connection from the session
connection = self.session.connection()
Base.metadata.create_all(bind=connection)
# Create an index on the vector column if it doesn't exist
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
)
)
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
"ON document_chunk (collection_name);"
)
)
self.session.commit()
print("Initialization complete.")
except Exception as e:
self.session.rollback()
print(f"Error during initialization: {e}")
raise
def adjust_vector_length(self, vector: List[float]) -> List[float]:
# Adjust vector to have length VECTOR_LENGTH
current_length = len(vector)
if current_length < VECTOR_LENGTH:
# Pad the vector with zeros
vector += [0.0] * (VECTOR_LENGTH - current_length)
elif current_length > VECTOR_LENGTH:
raise Exception(
f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
)
return vector
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
new_items = []
for item in items:
vector = self.adjust_vector_length(item["vector"])
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=item["metadata"],
)
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
print(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
print(f"Error during insert: {e}")
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
for item in items:
vector = self.adjust_vector_length(item["vector"])
existing = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.id == item["id"])
.first()
)
if existing:
existing.vector = vector
existing.text = item["text"]
existing.vmetadata = item["metadata"]
existing.collection_name = (
collection_name # Update collection_name if necessary
)
else:
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=item["metadata"],
)
self.session.add(new_chunk)
self.session.commit()
print(f"Upserted {len(items)} items into collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
print(f"Error during upsert: {e}")
raise
def search(
self,
collection_name: str,
vectors: List[List[float]],
limit: Optional[int] = None,
) -> Optional[SearchResult]:
try:
if not vectors:
return None
# Adjust query vectors to VECTOR_LENGTH
vectors = [self.adjust_vector_length(vector) for vector in vectors]
num_queries = len(vectors)
def vector_expr(vector):
return cast(array(vector), Vector(VECTOR_LENGTH))
# Create the values for query vectors
qid_col = column("qid", Integer)
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
query_vectors = (
values(qid_col, q_vector_col)
.data(
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
)
.alias("query_vectors")
)
# Build the lateral subquery for each query vector
subq = (
select(
DocumentChunk.id,
DocumentChunk.text,
DocumentChunk.vmetadata,
(
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
).label("distance"),
)
.where(DocumentChunk.collection_name == collection_name)
.order_by(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
)
)
if limit is not None:
subq = subq.limit(limit)
subq = subq.lateral("result")
# Build the main query by joining query_vectors and the lateral subquery
stmt = (
select(
query_vectors.c.qid,
subq.c.id,
subq.c.text,
subq.c.vmetadata,
subq.c.distance,
)
.select_from(query_vectors)
.join(subq, true())
.order_by(query_vectors.c.qid, subq.c.distance)
)
result_proxy = self.session.execute(stmt)
results = result_proxy.all()
ids = [[] for _ in range(num_queries)]
distances = [[] for _ in range(num_queries)]
documents = [[] for _ in range(num_queries)]
metadatas = [[] for _ in range(num_queries)]
if not results:
return SearchResult(
ids=ids,
distances=distances,
documents=documents,
metadatas=metadatas,
)
for row in results:
qid = int(row.qid)
ids[qid].append(row.id)
distances[qid].append(row.distance)
documents[qid].append(row.text)
metadatas[qid].append(row.vmetadata)
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
except Exception as e:
print(f"Error during search: {e}")
return None
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
for key, value in filter.items():
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
if limit is not None:
query = query.limit(limit)
results = query.all()
if not results:
return None
ids = [[result.id for result in results]]
documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]]
return GetResult(
ids=ids,
documents=documents,
metadatas=metadatas,
)
except Exception as e:
print(f"Error during query: {e}")
return None
def get(
self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if limit is not None:
query = query.limit(limit)
results = query.all()
if not results:
return None
ids = [[result.id for result in results]]
documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]]
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
print(f"Error during get: {e}")
return None
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict[str, Any]] = None,
) -> None:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if ids:
query = query.filter(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
deleted = query.delete(synchronize_session=False)
self.session.commit()
print(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
print(f"Error during delete: {e}")
raise
def reset(self) -> None:
try:
deleted = self.session.query(DocumentChunk).delete()
self.session.commit()
print(
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
)
except Exception as e:
self.session.rollback()
print(f"Error during reset: {e}")
raise
def close(self) -> None:
pass
def has_collection(self, collection_name: str) -> bool:
try:
exists = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.collection_name == collection_name)
.first()
is not None
)
return exists
except Exception as e:
print(f"Error checking collection existence: {e}")
return False
def delete_collection(self, collection_name: str) -> None:
self.delete(collection_name)
print(f"Collection '{collection_name}' deleted.")

View File

@ -4,8 +4,8 @@ from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
NO_LIMIT = 999999999
@ -14,7 +14,12 @@ class QdrantClient:
def __init__(self):
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
self.QDRANT_API_KEY = QDRANT_API_KEY
self.client = (
Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
if self.QDRANT_URI
else None
)
def _result_to_get_result(self, points) -> GetResult:
ids = []

View File

@ -0,0 +1,73 @@
import logging
import os
from pprint import pprint
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
import argparse
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
"""
Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview
"""
def search_bing(
subscription_key: str,
endpoint: str,
locale: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
mkt = locale
params = {"q": query, "mkt": mkt, "answerCount": count}
headers = {"Ocp-Apim-Subscription-Key": subscription_key}
try:
response = requests.get(endpoint, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("webPages", {}).get("value", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"],
title=result.get("name"),
snippet=result.get("snippet"),
)
for result in results
]
except Exception as ex:
log.error(f"Error: {ex}")
raise ex
def main():
parser = argparse.ArgumentParser(description="Search Bing from the command line.")
parser.add_argument(
"query",
type=str,
default="Top 10 international news today",
help="The search query.",
)
parser.add_argument(
"--count", type=int, default=10, help="Number of search results to return."
)
parser.add_argument(
"--filter", nargs="*", help="List of filters to apply to the search results."
)
parser.add_argument(
"--locale",
type=str,
default="en-US",
help="The locale to use for the search, maps to market in api",
)
args = parser.parse_args()
# search_bing expects the subscription key and endpoint first; the environment
# variable names below are assumptions for local testing.
subscription_key = os.environ["BING_SEARCH_V7_SUBSCRIPTION_KEY"]
endpoint = os.environ.get("BING_SEARCH_V7_ENDPOINT", "https://api.bing.microsoft.com/v7.0/search")
results = search_bing(subscription_key, endpoint, args.locale, args.query, args.count, args.filter)
pprint(results)
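A standalone usage sketch with the key and endpoint supplied explicitly; the environment variable name and module path are assumptions:

import os
from open_webui.retrieval.web.bing import search_bing  # module path assumed

results = search_bing(
    subscription_key=os.environ["BING_SEARCH_V7_SUBSCRIPTION_KEY"],  # env name assumed
    endpoint="https://api.bing.microsoft.com/v7.0/search",
    locale="en-US",
    query="Top 10 international news today",
    count=5,
)
for r in results:
    print(r.link, r.title)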

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -1,7 +1,7 @@
import logging
from typing import Optional
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS
from open_webui.env import SRC_LOG_LEVELS

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -1,7 +1,7 @@
import logging
import requests
from open_webui.apps.retrieval.web.main import SearchResult
from open_webui.retrieval.web.main import SearchResult
from open_webui.env import SRC_LOG_LEVELS
from yarl import URL
@ -9,7 +9,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_jina(query: str, count: int) -> list[SearchResult]:
def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
"""
Search using Jina's Search API and return the results as a list of SearchResult objects.
Args:
@ -20,9 +20,7 @@ def search_jina(query: str, count: int) -> list[SearchResult]:
list[SearchResult]: A list of search results
"""
jina_search_endpoint = "https://s.jina.ai/"
headers = {
"Accept": "application/json",
}
headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
url = str(URL(jina_search_endpoint + query))
response = requests.get(url, headers=headers)
response.raise_for_status()
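With the added Authorization header, the updated helper now requires a Jina API key; a brief calling sketch (module path assumed):

from open_webui.retrieval.web.jina_search import search_jina  # module path assumed

results = search_jina(api_key="jina_...", query="vector databases", count=5)
for r in results:
    print(r.link, r.title)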

View File

@ -0,0 +1,48 @@
import logging
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_kagi(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""Search using Kagi's Search API and return the results as a list of SearchResult objects.
The Search API will inherit the settings in your account, including results personalization and snippet length.
Args:
api_key (str): A Kagi Search API key
query (str): The query to search for
count (int): The number of results to return
"""
url = "https://kagi.com/api/v0/search"
headers = {
"Authorization": f"Bot {api_key}",
}
params = {"q": query, "limit": count}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
search_results = json_response.get("data", [])
results = [
SearchResult(
link=result["url"], title=result["title"], snippet=result.get("snippet")
)
for result in search_results
if result["t"] == 0
]
print(results)
if filter_list:
results = get_filtered_results(results, filter_list)
return results
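A brief usage sketch; since the Kagi API inherits account-level settings, only the key, query, and count are needed here (module path assumed):

from open_webui.retrieval.web.kagi import search_kagi  # module path assumed

results = search_kagi(api_key="...", query="pgvector cosine distance", count=5)
for r in results:
    print(r.link, r.title)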

View File

@ -0,0 +1,40 @@
import logging
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_mojeek(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""Search using Mojeek's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Mojeek Search API key
query (str): The query to search for
"""
url = "https://api.mojeek.com/search"
headers = {
"Accept": "application/json",
}
params = {"q": query, "api_key": api_key, "fmt": "json", "t": count}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("response", {}).get("results", [])
print(results)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("desc")
)
for result in results
]

View File

@ -3,7 +3,7 @@ from typing import Optional
from urllib.parse import urlencode
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -3,7 +3,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -3,7 +3,7 @@ from typing import Optional
from urllib.parse import urlencode
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -1,7 +1,7 @@
import logging
import requests
from open_webui.apps.retrieval.web.main import SearchResult
from open_webui.retrieval.web.main import SearchResult
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -0,0 +1,58 @@
{
"_type": "SearchResponse",
"queryContext": {
"originalQuery": "Top 10 international results"
},
"webPages": {
"webSearchUrl": "https://www.bing.com/search?q=Top+10+international+results",
"totalEstimatedMatches": 687,
"value": [
{
"id": "https://api.bing.microsoft.com/api/v7/#WebPages.0",
"name": "2024 Mexican Grand Prix - F1 results and latest standings ... - PlanetF1",
"url": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings",
"datePublished": "2024-10-27T00:00:00.0000000",
"datePublishedFreshnessText": "1 day ago",
"isFamilyFriendly": true,
"displayUrl": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings",
"snippet": "Nico Hulkenberg and Pierre Gasly completed the top 10. A full report of the Mexican Grand Prix is available at the bottom of this article. F1 results 2024 Mexican Grand Prix",
"dateLastCrawled": "2024-10-28T07:15:00.0000000Z",
"cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=916492551782&mkt=en-US&setlang=en-US&w=zBsfaAPyF2tUrHFHr_vFFdUm8sng4g34",
"language": "en",
"isNavigational": false,
"noCache": false
},
{
"id": "https://api.bing.microsoft.com/api/v7/#WebPages.1",
"name": "F1 Results Today: HUGE Verstappen penalties cause major title change",
"url": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max-verstappen-penalties-cause-major-title-change/",
"datePublished": "2024-10-27T00:00:00.0000000",
"datePublishedFreshnessText": "1 day ago",
"isFamilyFriendly": true,
"displayUrl": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max...",
"snippet": "Elsewhere, Mercedes duo Lewis Hamilton and George Russell came home in P4 and P5 respectively. Meanwhile, the surprise package of the day were Haas, with both Kevin Magnussen and Nico Hulkenberg finishing inside the points.. READ MORE: RB star issues apology after red flag CRASH at Mexican GP Mexican Grand Prix 2024 results. 1. Carlos Sainz [Ferrari] 2. Lando Norris [McLaren] - +4.705",
"dateLastCrawled": "2024-10-28T06:06:00.0000000Z",
"cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=2840656522642&mkt=en-US&setlang=en-US&w=-Tbkwxnq52jZCvG7l3CtgcwT1vwAjIUD",
"language": "en",
"isNavigational": false,
"noCache": false
},
{
"id": "https://api.bing.microsoft.com/api/v7/#WebPages.2",
"name": "International Power Rankings: England flying, Kangaroos cruising, Fiji rise",
"url": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying-kangaroos-cruising-fiji-rise",
"datePublished": "2024-10-28T00:00:00.0000000",
"datePublishedFreshnessText": "7 hours ago",
"isFamilyFriendly": true,
"displayUrl": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying...",
"snippet": "LRL RECOMMENDS: England player ratings from first Test against Samoa as omnificent George Williams scores perfect 10. 2. Australia (Men) SAME. The Kangaroos remain 2nd in our Power Rankings after their 22-10 win against New Zealand in Christchurch on Sunday. As was the case in their win against Tonga last week, Mal Meningas side weren ...",
"dateLastCrawled": "2024-10-28T07:09:00.0000000Z",
"cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=1535008462672&mkt=en-US&setlang=en-US&w=82ujhH4Kp0iuhCS7wh1xLUFYUeetaVVm",
"language": "en",
"isNavigational": false,
"noCache": false
}
],
"someResultsRemoved": true
}
}

View File

@ -0,0 +1,703 @@
import hashlib
import json
import logging
import os
import uuid
from functools import lru_cache
from pathlib import Path
from pydub import AudioSegment
from pydub.silence import split_on_silence
import aiohttp
import aiofiles
import requests
from fastapi import (
Depends,
FastAPI,
File,
HTTPException,
Request,
UploadFile,
status,
APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import (
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
CACHE_DIR,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
DEVICE_TYPE,
ENABLE_FORWARD_USER_INFO_HEADERS,
)
router = APIRouter()
# Constants
MAX_FILE_SIZE_MB = 25
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
##########################################
#
# Utility functions
#
##########################################
from pydub.utils import mediainfo
def is_mp4_audio(file_path):
"""Check if the given file is an MP4 audio file."""
if not os.path.isfile(file_path):
print(f"File not found: {file_path}")
return False
info = mediainfo(file_path)
if (
info.get("codec_name") == "aac"
and info.get("codec_type") == "audio"
and info.get("codec_tag_string") == "mp4a"
):
return True
return False
def convert_mp4_to_wav(file_path, output_path):
"""Convert MP4 audio file to WAV format."""
audio = AudioSegment.from_file(file_path, format="mp4")
audio.export(output_path, format="wav")
print(f"Converted {file_path} to {output_path}")
def set_faster_whisper_model(model: str, auto_update: bool = False):
whisper_model = None
if model:
from faster_whisper import WhisperModel
faster_whisper_kwargs = {
"model_size_or_path": model,
"device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not auto_update,
}
try:
whisper_model = WhisperModel(**faster_whisper_kwargs)
except Exception:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
faster_whisper_kwargs["local_files_only"] = False
whisper_model = WhisperModel(**faster_whisper_kwargs)
return whisper_model
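A local-transcription sketch using the helper above; the model size and audio file are placeholders:

model = set_faster_whisper_model("base", auto_update=True)  # "base" is an example model size
if model:
    segments, info = model.transcribe("sample.wav", beam_size=5)
    text = "".join(segment.text for segment in segments)
    print(info.language, text.strip())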
##########################################
#
# Audio API
#
##########################################
class TTSConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
API_KEY: str
ENGINE: str
MODEL: str
VOICE: str
SPLIT_ON: str
AZURE_SPEECH_REGION: str
AZURE_SPEECH_OUTPUT_FORMAT: str
class STTConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
ENGINE: str
MODEL: str
WHISPER_MODEL: str
class AudioConfigUpdateForm(BaseModel):
tts: TTSConfigForm
stt: STTConfigForm
@router.get("/config")
async def get_audio_config(request: Request, user=Depends(get_admin_user)):
return {
"tts": {
"OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
"API_KEY": request.app.state.config.TTS_API_KEY,
"ENGINE": request.app.state.config.TTS_ENGINE,
"MODEL": request.app.state.config.TTS_MODEL,
"VOICE": request.app.state.config.TTS_VOICE,
"SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
"AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
"AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
},
"stt": {
"OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
"ENGINE": request.app.state.config.STT_ENGINE,
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
},
}
@router.post("/config/update")
async def update_audio_config(
request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
request.app.state.config.TTS_MODEL = form_data.tts.MODEL
request.app.state.config.TTS_VOICE = form_data.tts.VOICE
request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
)
request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
request.app.state.config.STT_MODEL = form_data.stt.MODEL
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
if request.app.state.config.STT_ENGINE == "":
request.app.state.faster_whisper_model = set_faster_whisper_model(
form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
)
return {
"tts": {
"OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
"API_KEY": request.app.state.config.TTS_API_KEY,
"ENGINE": request.app.state.config.TTS_ENGINE,
"MODEL": request.app.state.config.TTS_MODEL,
"VOICE": request.app.state.config.TTS_VOICE,
"SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
"AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
"AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
},
"stt": {
"OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
"ENGINE": request.app.state.config.STT_ENGINE,
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
},
}
def load_speech_pipeline(request):
from transformers import pipeline
from datasets import load_dataset
if request.app.state.speech_synthesiser is None:
request.app.state.speech_synthesiser = pipeline(
"text-to-speech", "microsoft/speecht5_tts"
)
if request.app.state.speech_speaker_embeddings_dataset is None:
request.app.state.speech_speaker_embeddings_dataset = load_dataset(
"Matthijs/cmu-arctic-xvectors", split="validation"
)
@router.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
body = await request.body()
name = hashlib.sha256(body).hexdigest()
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
# Check if the file already exists in the cache
if file_path.is_file():
return FileResponse(file_path)
payload = None
try:
payload = json.loads(body.decode("utf-8"))
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
if request.app.state.config.TTS_ENGINE == "openai":
payload["model"] = request.app.state.config.TTS_MODEL
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
data=payload,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
) as r:
r.raise_for_status()
async with aiofiles.open(file_path, "wb") as f:
await f.write(await r.read())
async with aiofiles.open(file_body_path, "w") as f:
await f.write(json.dumps(json.loads(body.decode("utf-8"))))
return FileResponse(file_path)
except Exception as e:
log.exception(e)
detail = None
try:
if r.status != 200:
res = await r.json()
if "error" in res:
detail = f"External: {res['error'].get('message', '')}"
except Exception:
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
detail=detail if detail else "Open WebUI: Server Connection Error",
)
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
voice_id = payload.get("voice", "")
if voice_id not in get_available_voices(request):
raise HTTPException(
status_code=400,
detail="Invalid voice id",
)
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
json={
"text": payload["input"],
"model_id": request.app.state.config.TTS_MODEL,
"voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
},
headers={
"Accept": "audio/mpeg",
"Content-Type": "application/json",
"xi-api-key": request.app.state.config.TTS_API_KEY,
},
) as r:
r.raise_for_status()
async with aiofiles.open(file_path, "wb") as f:
await f.write(await r.read())
async with aiofiles.open(file_body_path, "w") as f:
await f.write(json.dumps(json.loads(body.decode("utf-8"))))
return FileResponse(file_path)
except Exception as e:
log.exception(e)
detail = None
try:
if r.status != 200:
res = await r.json()
if "error" in res:
detail = f"External: {res['error'].get('message', '')}"
except Exception:
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
detail=detail if detail else "Open WebUI: Server Connection Error",
)
elif request.app.state.config.TTS_ENGINE == "azure":
try:
payload = json.loads(body.decode("utf-8"))
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
region = request.app.state.config.TTS_AZURE_SPEECH_REGION
language = request.app.state.config.TTS_VOICE
locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
try:
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
<voice name="{language}">{payload["input"]}</voice>
</speak>"""
async with aiohttp.ClientSession() as session:
async with session.post(
f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
headers={
"Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
"Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": output_format,
},
data=data,
) as r:
r.raise_for_status()
async with aiofiles.open(file_path, "wb") as f:
await f.write(await r.read())
return FileResponse(file_path)
except Exception as e:
log.exception(e)
detail = None
try:
if r.status != 200:
res = await r.json()
if "error" in res:
detail = f"External: {res['error'].get('message', '')}"
except Exception:
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
detail=detail if detail else "Open WebUI: Server Connection Error",
)
elif request.app.state.config.TTS_ENGINE == "transformers":
payload = None
try:
payload = json.loads(body.decode("utf-8"))
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
import torch
import soundfile as sf
load_speech_pipeline(request)
embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
speaker_index = 6799
try:
speaker_index = embeddings_dataset["filename"].index(
request.app.state.config.TTS_MODEL
)
except Exception:
pass
speaker_embedding = torch.tensor(
embeddings_dataset[speaker_index]["xvector"]
).unsqueeze(0)
speech = request.app.state.speech_synthesiser(
payload["input"],
forward_params={"speaker_embeddings": speaker_embedding},
)
sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
with open(file_body_path, "w") as f:
json.dump(json.loads(body.decode("utf-8")), f)
return FileResponse(file_path)
def transcribe(request: Request, file_path):
print("transcribe", file_path)
filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path)
id = filename.split(".")[0]
if request.app.state.config.STT_ENGINE == "":
if request.app.state.faster_whisper_model is None:
request.app.state.faster_whisper_model = set_faster_whisper_model(
request.app.state.config.WHISPER_MODEL
)
model = request.app.state.faster_whisper_model
segments, info = model.transcribe(file_path, beam_size=5)
log.info(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
transcript = "".join([segment.text for segment in list(segments)])
data = {"text": transcript.strip()}
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
log.debug(data)
return data
elif request.app.state.config.STT_ENGINE == "openai":
if is_mp4_audio(file_path):
os.rename(file_path, file_path.replace(".wav", ".mp4"))
# Convert MP4 audio file to WAV format
convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
r = None
try:
r = requests.post(
url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
headers={
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
},
files={"file": (filename, open(file_path, "rb"))},
data={"model": request.app.state.config.STT_MODEL},
)
r.raise_for_status()
data = r.json()
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
return data
except Exception as e:
log.exception(e)
detail = None
if r is not None:
try:
res = r.json()
if "error" in res:
detail = f"External: {res['error'].get('message', '')}"
except Exception:
detail = f"External: {e}"
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
def compress_audio(file_path):
if os.path.getsize(file_path) > MAX_FILE_SIZE:
file_dir = os.path.dirname(file_path)
audio = AudioSegment.from_file(file_path)
audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
id = os.path.splitext(os.path.basename(file_path))[0]  # derive the id from the source file name
compressed_path = f"{file_dir}/{id}_compressed.opus"
audio.export(compressed_path, format="opus", bitrate="32k")
log.debug(f"Compressed audio to {compressed_path}")
if (
os.path.getsize(compressed_path) > MAX_FILE_SIZE
): # Still larger than MAX_FILE_SIZE after compression
raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
return compressed_path
else:
return file_path
@router.post("/transcriptions")
def transcription(
request: Request,
file: UploadFile = File(...),
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}")
if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
)
try:
ext = file.filename.split(".")[-1]
id = uuid.uuid4()
filename = f"{id}.{ext}"
contents = file.file.read()
file_dir = f"{CACHE_DIR}/audio/transcriptions"
os.makedirs(file_dir, exist_ok=True)
file_path = f"{file_dir}/{filename}"
with open(file_path, "wb") as f:
f.write(contents)
try:
try:
file_path = compress_audio(file_path)
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
data = transcribe(request, file_path)
file_path = file_path.split("/")[-1]
return {**data, "filename": file_path}
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
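A client-side sketch of calling the transcription endpoint; the mount prefix and bearer token are assumptions:

import requests

with open("sample.wav", "rb") as f:
    r = requests.post(
        "http://localhost:8080/api/v1/audio/transcriptions",  # prefix assumed
        headers={"Authorization": "Bearer <token>"},
        files={"file": ("sample.wav", f, "audio/wav")},
    )
r.raise_for_status()
print(r.json()["text"], r.json()["filename"])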
def get_available_models(request: Request) -> list[dict]:
available_models = []
if request.app.state.config.TTS_ENGINE == "openai":
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
try:
response = requests.get(
"https://api.elevenlabs.io/v1/models",
headers={
"xi-api-key": request.app.state.config.TTS_API_KEY,
"Content-Type": "application/json",
},
timeout=5,
)
response.raise_for_status()
models = response.json()
available_models = [
{"name": model["name"], "id": model["model_id"]} for model in models
]
except requests.RequestException as e:
log.error(f"Error fetching voices: {str(e)}")
return available_models
@router.get("/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
return {"models": get_available_models(request)}
def get_available_voices(request) -> dict:
"""Returns {voice_id: voice_name} dict"""
available_voices = {}
if request.app.state.config.TTS_ENGINE == "openai":
available_voices = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
try:
available_voices = get_elevenlabs_voices(
api_key=request.app.state.config.TTS_API_KEY
)
except Exception:
# Avoided @lru_cache with exception
pass
elif request.app.state.config.TTS_ENGINE == "azure":
try:
region = request.app.state.config.TTS_AZURE_SPEECH_REGION
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
headers = {
"Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
}
response = requests.get(url, headers=headers)
response.raise_for_status()
voices = response.json()
for voice in voices:
available_voices[voice["ShortName"]] = (
f"{voice['DisplayName']} ({voice['ShortName']})"
)
except requests.RequestException as e:
log.error(f"Error fetching voices: {str(e)}")
return available_voices
@lru_cache
def get_elevenlabs_voices(api_key: str) -> dict:
"""
Note, set the following in your .env file to use Elevenlabs:
AUDIO_TTS_ENGINE=elevenlabs
AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key
AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices
AUDIO_TTS_MODEL=eleven_multilingual_v2
"""
try:
# TODO: Add retries
response = requests.get(
"https://api.elevenlabs.io/v1/voices",
headers={
"xi-api-key": api_key,
"Content-Type": "application/json",
},
)
response.raise_for_status()
voices_data = response.json()
voices = {}
for voice in voices_data.get("voices", []):
voices[voice["voice_id"]] = voice["name"]
except requests.RequestException as e:
# Avoid @lru_cache with exception
log.error(f"Error fetching voices: {str(e)}")
raise RuntimeError(f"Error fetching voices: {str(e)}")
return voices
@router.get("/voices")
async def get_voices(request: Request, user=Depends(get_verified_user)):
return {
"voices": [
{"id": k, "name": v} for k, v in get_available_voices(request).items()
]
}

View File

@ -2,12 +2,15 @@ import re
import uuid
import time
import datetime
import logging
from aiohttp import ClientSession
from open_webui.apps.webui.models.auths import (
from open_webui.models.auths import (
AddUserForm,
ApiKey,
Auths,
Token,
LdapForm,
SigninForm,
SigninResponse,
SignupForm,
@ -15,20 +18,26 @@ from open_webui.apps.webui.models.auths import (
UpdateProfileForm,
UserResponse,
)
from open_webui.apps.webui.models.users import Users
from open_webui.config import WEBUI_AUTH
from open_webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import (
WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
SRC_LOG_LEVELS,
)
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import Response
from fastapi.responses import RedirectResponse, Response
from open_webui.config import (
OPENID_PROVIDER_URL,
ENABLE_OAUTH_SIGNUP,
)
from pydantic import BaseModel
from open_webui.utils.misc import parse_duration, validate_email_format
from open_webui.utils.utils import (
from open_webui.utils.auth import (
create_api_key,
create_token,
get_admin_user,
@ -37,10 +46,19 @@ from open_webui.utils.utils import (
get_password_hash,
)
from open_webui.utils.webhook import post_webhook
from typing import Optional
from open_webui.utils.access_control import get_permissions
from typing import Optional, List
from ssl import CERT_REQUIRED, PROTOCOL_TLS
from ldap3 import Server, Connection, ALL, Tls
from ldap3.utils.conv import escape_filter_chars
router = APIRouter()
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
############################
# GetSessionUser
############################
@ -48,6 +66,7 @@ router = APIRouter()
class SessionUserResponse(Token, UserResponse):
expires_at: Optional[int] = None
permissions: Optional[dict] = None
@router.get("/", response_model=SessionUserResponse)
@ -80,6 +99,10 @@ async def get_session_user(
secure=WEBUI_SESSION_COOKIE_SECURE,
)
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
return {
"token": token,
"token_type": "Bearer",
@ -89,6 +112,7 @@ async def get_session_user(
"name": user.name,
"role": user.role,
"profile_image_url": user.profile_image_url,
"permissions": user_permissions,
}
@ -137,6 +161,146 @@ async def update_password(
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
############################
# LDAP Authentication
############################
@router.post("/ldap", response_model=SigninResponse)
async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
ENABLE_LDAP = request.app.state.config.ENABLE_LDAP
LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL
LDAP_SERVER_HOST = request.app.state.config.LDAP_SERVER_HOST
LDAP_SERVER_PORT = request.app.state.config.LDAP_SERVER_PORT
LDAP_ATTRIBUTE_FOR_USERNAME = request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME
LDAP_SEARCH_BASE = request.app.state.config.LDAP_SEARCH_BASE
LDAP_SEARCH_FILTERS = request.app.state.config.LDAP_SEARCH_FILTERS
LDAP_APP_DN = request.app.state.config.LDAP_APP_DN
LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
LDAP_CIPHERS = (
request.app.state.config.LDAP_CIPHERS
if request.app.state.config.LDAP_CIPHERS
else "ALL"
)
if not ENABLE_LDAP:
raise HTTPException(400, detail="LDAP authentication is not enabled")
try:
tls = Tls(
validate=CERT_REQUIRED,
version=PROTOCOL_TLS,
ca_certs_file=LDAP_CA_CERT_FILE,
ciphers=LDAP_CIPHERS,
)
except Exception as e:
log.error(f"An error occurred on TLS: {str(e)}")
raise HTTPException(400, detail=str(e))
try:
server = Server(
host=LDAP_SERVER_HOST,
port=LDAP_SERVER_PORT,
get_info=ALL,
use_ssl=LDAP_USE_TLS,
tls=tls,
)
connection_app = Connection(
server,
LDAP_APP_DN,
LDAP_APP_PASSWORD,
auto_bind="NONE",
authentication="SIMPLE",
)
if not connection_app.bind():
raise HTTPException(400, detail="Application account bind failed")
search_success = connection_app.search(
search_base=LDAP_SEARCH_BASE,
search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})",
attributes=[f"{LDAP_ATTRIBUTE_FOR_USERNAME}", "mail", "cn"],
)
if not search_success:
raise HTTPException(400, detail="User not found in the LDAP server")
entry = connection_app.entries[0]
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
mail = str(entry["mail"])
cn = str(entry["cn"])
user_dn = entry.entry_dn
if username == form_data.user.lower():
connection_user = Connection(
server,
user_dn,
form_data.password,
auto_bind="NONE",
authentication="SIMPLE",
)
if not connection_user.bind():
raise HTTPException(400, f"Authentication failed for {form_data.user}")
user = Users.get_user_by_email(mail)
if not user:
try:
role = (
"admin"
if Users.get_num_users() == 0
else request.app.state.config.DEFAULT_USER_ROLE
)
user = Auths.insert_new_auth(
email=mail, password=str(uuid.uuid4()), name=cn, role=role
)
if not user:
raise HTTPException(
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
)
except HTTPException:
raise
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
user = Auths.authenticate_user_by_trusted_header(mail)
if user:
token = create_token(
data={"id": user.id},
expires_delta=parse_duration(
request.app.state.config.JWT_EXPIRES_IN
),
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
)
return {
"token": token,
"token_type": "Bearer",
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
"profile_image_url": user.profile_image_url,
}
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
else:
raise HTTPException(
400,
f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
)
except Exception as e:
raise HTTPException(400, detail=str(e))
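A client-side sketch of the LDAP sign-in route added above; the mount prefix and form field names are assumptions:

import requests

r = requests.post(
    "http://localhost:8080/api/v1/auths/ldap",  # prefix assumed
    json={"user": "jdoe", "password": "secret"},  # LdapForm field names assumed
)
r.raise_for_status()
print(r.json()["email"], r.json()["role"])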
############################
# SignIn
############################
@ -211,6 +375,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
secure=WEBUI_SESSION_COOKIE_SECURE,
)
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
return {
"token": token,
"token_type": "Bearer",
@ -220,6 +388,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
"name": user.name,
"role": user.role,
"profile_image_url": user.profile_image_url,
"permissions": user_permissions,
}
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
@ -260,6 +429,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
if Users.get_num_users() == 0
else request.app.state.config.DEFAULT_USER_ROLE
)
if Users.get_num_users() == 0:
# Disable signup after the first user is created
request.app.state.config.ENABLE_SIGNUP = False
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(
form_data.email.lower(),
@ -307,6 +481,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
},
)
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
return {
"token": token,
"token_type": "Bearer",
@ -316,6 +494,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
"name": user.name,
"role": user.role,
"profile_image_url": user.profile_image_url,
"permissions": user_permissions,
}
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
@ -324,8 +503,31 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
@router.get("/signout")
async def signout(response: Response):
async def signout(request: Request, response: Response):
response.delete_cookie("token")
if ENABLE_OAUTH_SIGNUP.value:
oauth_id_token = request.cookies.get("oauth_id_token")
if oauth_id_token:
try:
async with ClientSession() as session:
async with session.get(OPENID_PROVIDER_URL.value) as resp:
if resp.status == 200:
openid_data = await resp.json()
logout_url = openid_data.get("end_session_endpoint")
if logout_url:
response.delete_cookie("oauth_id_token")
return RedirectResponse(
url=f"{logout_url}?id_token_hint={oauth_id_token}"
)
else:
raise HTTPException(
status_code=resp.status,
detail="Failed to fetch OpenID configuration",
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
return {"status": True}
@ -413,6 +615,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
@ -423,6 +626,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
class AdminConfig(BaseModel):
SHOW_ADMIN_DETAILS: bool
ENABLE_SIGNUP: bool
ENABLE_API_KEY: bool
DEFAULT_USER_ROLE: str
JWT_EXPIRES_IN: str
ENABLE_COMMUNITY_SHARING: bool
@ -435,6 +639,7 @@ async def update_admin_config(
):
request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
@ -453,6 +658,7 @@ async def update_admin_config(
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
@ -460,6 +666,105 @@ async def update_admin_config(
}
class LdapServerConfig(BaseModel):
label: str
host: str
port: Optional[int] = None
attribute_for_username: str = "uid"
app_dn: str
app_dn_password: str
search_base: str
search_filters: str = ""
use_tls: bool = True
certificate_path: Optional[str] = None
ciphers: Optional[str] = "ALL"
@router.get("/admin/config/ldap/server", response_model=LdapServerConfig)
async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
return {
"label": request.app.state.config.LDAP_SERVER_LABEL,
"host": request.app.state.config.LDAP_SERVER_HOST,
"port": request.app.state.config.LDAP_SERVER_PORT,
"attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME,
"app_dn": request.app.state.config.LDAP_APP_DN,
"app_dn_password": request.app.state.config.LDAP_APP_PASSWORD,
"search_base": request.app.state.config.LDAP_SEARCH_BASE,
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
"use_tls": request.app.state.config.LDAP_USE_TLS,
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
"ciphers": request.app.state.config.LDAP_CIPHERS,
}
@router.post("/admin/config/ldap/server")
async def update_ldap_server(
request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user)
):
required_fields = [
"label",
"host",
"attribute_for_username",
"app_dn",
"app_dn_password",
"search_base",
]
for key in required_fields:
value = getattr(form_data, key)
if not value:
raise HTTPException(400, detail=f"Required field {key} is empty")
if form_data.use_tls and not form_data.certificate_path:
raise HTTPException(
400, detail="TLS is enabled but certificate file path is missing"
)
request.app.state.config.LDAP_SERVER_LABEL = form_data.label
request.app.state.config.LDAP_SERVER_HOST = form_data.host
request.app.state.config.LDAP_SERVER_PORT = form_data.port
request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = (
form_data.attribute_for_username
)
request.app.state.config.LDAP_APP_DN = form_data.app_dn
request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password
request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base
request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters
request.app.state.config.LDAP_USE_TLS = form_data.use_tls
request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path
request.app.state.config.LDAP_CIPHERS = form_data.ciphers
return {
"label": request.app.state.config.LDAP_SERVER_LABEL,
"host": request.app.state.config.LDAP_SERVER_HOST,
"port": request.app.state.config.LDAP_SERVER_PORT,
"attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME,
"app_dn": request.app.state.config.LDAP_APP_DN,
"app_dn_password": request.app.state.config.LDAP_APP_PASSWORD,
"search_base": request.app.state.config.LDAP_SEARCH_BASE,
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
"use_tls": request.app.state.config.LDAP_USE_TLS,
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
"ciphers": request.app.state.config.LDAP_CIPHERS,
}
@router.get("/admin/config/ldap")
async def get_ldap_config(request: Request, user=Depends(get_admin_user)):
return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
class LdapConfigForm(BaseModel):
enable_ldap: Optional[bool] = None
@router.post("/admin/config/ldap")
async def update_ldap_config(
request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.ENABLE_LDAP = form_data.enable_ldap
return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
############################
# API Key
############################
@ -467,9 +772,16 @@ async def update_admin_config(
# create api key
@router.post("/api_key", response_model=ApiKey)
async def create_api_key_(user=Depends(get_current_user)):
async def generate_api_key(request: Request, user=Depends(get_current_user)):
if not request.app.state.config.ENABLE_API_KEY:
raise HTTPException(
status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED,
)
api_key = create_api_key()
success = Users.update_user_api_key_by_id(user.id, api_key)
if success:
return {
"api_key": api_key,

View File

@ -2,22 +2,25 @@ import json
import logging
from typing import Optional
from open_webui.apps.webui.models.chats import (
from open_webui.models.chats import (
ChatForm,
ChatImportForm,
ChatResponse,
Chats,
ChatTitleIdResponse,
)
from open_webui.apps.webui.models.tags import TagModel, Tags
from open_webui.apps.webui.models.folders import Folders
from open_webui.models.tags import TagModel, Tags
from open_webui.models.folders import Folders
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_permission
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -50,9 +53,10 @@ async def get_session_user_chat_list(
@router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
if user.role == "user" and not request.app.state.config.USER_PERMISSIONS.get(
"chat", {}
).get("deletion", {}):
if user.role == "user" and not has_permission(
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
@ -385,8 +389,8 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
return result
else:
if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get(
"deletion", {}
if not has_permission(
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -603,7 +607,6 @@ async def add_tag_by_id_and_tag_name(
detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
)
print(tags, tag_id)
if tag_id not in tags:
Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
id, user.id, form_data.name

View File

@ -1,10 +1,12 @@
from open_webui.config import BannerModel
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from open_webui.utils.utils import get_admin_user, get_verified_user
from typing import Optional
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import get_config, save_config
from open_webui.config import BannerModel
router = APIRouter()
@ -34,8 +36,32 @@ async def export_config(user=Depends(get_admin_user)):
return get_config()
class SetDefaultModelsForm(BaseModel):
models: str
############################
# SetDefaultModels
############################
class ModelsConfigForm(BaseModel):
DEFAULT_MODELS: Optional[str]
MODEL_ORDER_LIST: Optional[list[str]]
@router.get("/models", response_model=ModelsConfigForm)
async def get_models_config(request: Request, user=Depends(get_admin_user)):
return {
"DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
"MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
}
@router.post("/models", response_model=ModelsConfigForm)
async def set_models_config(
request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS
request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST
return {
"DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
"MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
}
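A sketch of the consolidated models-config route, which replaces the removed /default/models endpoint shown below; the mount prefix and admin token are assumptions:

import requests

r = requests.post(
    "http://localhost:8080/api/v1/configs/models",  # prefix assumed
    headers={"Authorization": "Bearer <admin token>"},
    json={
        "DEFAULT_MODELS": "llama3.1:8b",
        "MODEL_ORDER_LIST": ["llama3.1:8b", "gpt-4o"],
    },
)
print(r.json())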
class PromptSuggestion(BaseModel):
@ -47,21 +73,8 @@ class SetDefaultSuggestionsForm(BaseModel):
suggestions: list[PromptSuggestion]
############################
# SetDefaultModels
############################
@router.post("/default/models", response_model=str)
async def set_global_default_models(
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
):
request.app.state.config.DEFAULT_MODELS = form_data.models
return request.app.state.config.DEFAULT_MODELS
@router.post("/default/suggestions", response_model=list[PromptSuggestion])
async def set_global_default_suggestions(
@router.post("/suggestions", response_model=list[PromptSuggestion])
async def set_default_suggestions(
request: Request,
form_data: SetDefaultSuggestionsForm,
user=Depends(get_admin_user),

View File

@ -2,8 +2,8 @@ from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from pydantic import BaseModel
from open_webui.apps.webui.models.users import Users, UserModel
from open_webui.apps.webui.models.feedbacks import (
from open_webui.models.users import Users, UserModel
from open_webui.models.feedbacks import (
FeedbackModel,
FeedbackResponse,
FeedbackForm,
@ -11,7 +11,7 @@ from open_webui.apps.webui.models.feedbacks import (
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
router = APIRouter()

View File

@ -5,27 +5,28 @@ from pathlib import Path
from typing import Optional
from pydantic import BaseModel
import mimetypes
from urllib.parse import quote
from open_webui.storage.provider import Storage
from open_webui.apps.webui.models.files import (
from open_webui.models.files import (
FileForm,
FileModel,
FileModelResponse,
Files,
)
from open_webui.apps.retrieval.main import process_file, ProcessFileForm
from open_webui.routers.retrieval import process_file, ProcessFileForm
from open_webui.config import UPLOAD_DIR
from open_webui.env import SRC_LOG_LEVELS
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -39,7 +40,9 @@ router = APIRouter()
@router.post("/", response_model=FileModelResponse)
def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
def upload_file(
request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
):
log.info(f"file.content_type: {file.content_type}")
try:
unsanitized_filename = file.filename
@ -56,7 +59,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
FileForm(
**{
"id": id,
"filename": filename,
"filename": name,
"path": file_path,
"meta": {
"name": name,
@ -68,7 +71,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
)
try:
process_file(ProcessFileForm(file_id=id))
process_file(request, ProcessFileForm(file_id=id))
file_item = Files.get_file_by_id(id=id)
except Exception as e:
log.exception(e)
@ -183,13 +186,15 @@ class ContentForm(BaseModel):
@router.post("/{id}/data/content/update")
async def update_file_data_content_by_id(
id: str, form_data: ContentForm, user=Depends(get_verified_user)
request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user)
):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
try:
process_file(ProcessFileForm(file_id=id, content=form_data.content))
process_file(
request, ProcessFileForm(file_id=id, content=form_data.content)
)
file = Files.get_file_by_id(id=id)
except Exception as e:
log.exception(e)
@ -218,11 +223,15 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
# Handle Unicode filenames
filename = file.meta.get("name", file.filename)
encoded_filename = quote(filename) # RFC5987 encoding
headers = {
"Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
}
return FileResponse(file_path, headers=headers)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -279,16 +288,20 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
if file and (file.user_id == user.id or user.role == "admin"):
file_path = file.path
# Handle Unicode filenames
filename = file.meta.get("name", file.filename)
encoded_filename = quote(filename) # RFC5987 encoding
headers = {
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
}
if file_path:
file_path = Storage.get_file(file_path)
file_path = Path(file_path)
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
headers = {
"Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
}
return FileResponse(file_path, headers=headers)
else:
raise HTTPException(
@ -307,7 +320,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
return StreamingResponse(
generator(),
media_type="text/plain",
headers={"Content-Disposition": f"attachment; filename={file_name}"},
headers=headers,
)
else:
raise HTTPException(

View File

@ -8,12 +8,12 @@ from pydantic import BaseModel
import mimetypes
from open_webui.apps.webui.models.folders import (
from open_webui.models.folders import (
FolderForm,
FolderModel,
Folders,
)
from open_webui.apps.webui.models.chats import Chats
from open_webui.models.chats import Chats
from open_webui.config import UPLOAD_DIR
from open_webui.env import SRC_LOG_LEVELS
@ -24,7 +24,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -2,17 +2,17 @@ import os
from pathlib import Path
from typing import Optional
from open_webui.apps.webui.models.functions import (
from open_webui.models.functions import (
FunctionForm,
FunctionModel,
FunctionResponse,
Functions,
)
from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
from open_webui.utils.plugin import load_function_module_by_id, replace_imports
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
router = APIRouter()

Some files were not shown because too many files have changed in this diff.