diff --git a/.github/workflows/deploy-to-hf-spaces.yml b/.github/workflows/deploy-to-hf-spaces.yml index aa8bbcfce..7fc66acf5 100644 --- a/.github/workflows/deploy-to-hf-spaces.yml +++ b/.github/workflows/deploy-to-hf-spaces.yml @@ -28,6 +28,8 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 + with: + lfs: true - name: Remove git history run: rm -rf .git @@ -52,7 +54,9 @@ jobs: - name: Set up Git and push to Space run: | git init --initial-branch=main + git lfs install git lfs track "*.ttf" + git lfs track "*.jpg" rm demo.gif git add . git commit -m "GitHub deploy: ${{ github.sha }}" diff --git a/CHANGELOG.md b/CHANGELOG.md index 05c9f369c..2be2e872f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,67 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.4.1] - 2024-11-19 + +### Added + +- **📊 Enhanced Feedback System**: Introduced a detailed 1-10 rating scale for feedback alongside thumbs up/down, preparing for more precise model fine-tuning and improving feedback quality. +- **ℹ️ Tool Descriptions on Hover**: Easily access tool descriptions by hovering over the message input, providing a smoother workflow with more context when utilizing tools. + +### Fixed + +- **🗑️ Graceful Handling of Deleted Users**: Resolved an issue where deleted users caused workspace items (models, knowledge, prompts, tools) to fail, ensuring reliable workspace loading. +- **🔑 API Key Creation**: Fixed an issue preventing users from creating new API keys, restoring secure and seamless API management. +- **🔗 HTTPS Proxy Fix**: Corrected HTTPS proxy issues affecting the '/api/v1/models/' endpoint, ensuring smoother, uninterrupted model management. 
+ +## [0.4.0] - 2024-11-19 + +### Added + +- **👥 User Groups**: You can now create and manage user groups, making user organization seamless. +- **🔐 Group-Based Access Control**: Set granular access to models, knowledge, prompts, and tools based on user groups, allowing for more controlled and secure environments. +- **🛠️ Group-Based User Permissions**: Easily manage workspace permissions. Grant users the ability to upload files, delete, edit, or create temporary chats, as well as define their ability to create models, knowledge, prompts, and tools. +- **🔑 LDAP Support**: Newly introduced LDAP authentication adds robust security and scalability to user management. +- **🌐 Enhanced OpenAI-Compatible Connections**: Added prefix ID support to avoid model ID clashes, with explicit model ID support for APIs lacking '/models' endpoint support, ensuring smooth operation with custom setups. +- **🔐 Ollama API Key Support**: Now manage credentials for Ollama when set behind proxies, including the option to utilize prefix ID for proper distinction across multiple Ollama instances. +- **🔄 Connection Enable/Disable Toggle**: Easily enable or disable individual OpenAI and Ollama connections as needed. +- **🎨 Redesigned Model Workspace**: Freshly redesigned to improve usability for managing models across users and groups. +- **🎨 Redesigned Prompt Workspace**: A fresh UI to conveniently organize and manage prompts. +- **🧩 Sorted Functions Workspace**: Functions are now automatically categorized by type (Action, Filter, Pipe), streamlining management. +- **💻 Redesigned Collaborative Workspace**: Enhanced support for multiple users contributing to models, knowledge, prompts, or tools, improving collaboration. +- **🔧 Auto-Selected Tools in Model Editor**: Tools enabled through the model editor are now automatically selected, whereas previously it only gave users the option to enable the tool, reducing manual steps and enhancing efficiency. 
+- **🔔 Web Search & Tools Indicator**: A clear indication now shows when web search or tools are active, reducing confusion. +- **🔑 Toggle API Key Auth**: Tighten security by easily enabling or disabling API key authentication option for Open WebUI. +- **🗂️ Agentic Retrieval**: Improve RAG accuracy via smart pre-processing of chat history to determine the best queries before retrieval. +- **📁 Large Text as File Option**: Optionally convert large pasted text into a file upload, keeping the chat interface cleaner. +- **🗂️ Toggle Citations for Models**: Ability to disable citations has been introduced in the model editor. +- **🔍 User Settings Search**: Quickly search for settings fields, improving ease of use and navigation. +- **🗣️ Experimental SpeechT5 TTS**: Local SpeechT5 support added for improved text-to-speech capabilities. +- **🔄 Unified Reset for Models**: A one-click option has been introduced to reset and remove all models from the Admin Settings. +- **🛠️ Initial Setup Wizard**: The setup process now explicitly informs users that they are creating an admin account during the first-time setup, ensuring clarity. Previously, users encountered the login page right away without this distinction. +- **🌐 Enhanced Translations**: Several language translations, including Ukrainian, Norwegian, and Brazilian Portuguese, were refined for better localization. + +### Fixed + +- **🎥 YouTube Video Attachments**: Fixed issues preventing proper loading and attachment of YouTube videos as files. +- **🔄 Shared Chat Update**: Corrected issues where shared chats were not updating, improving collaboration consistency. +- **🔍 DuckDuckGo Rate Limit Fix**: Addressed issues with DuckDuckGo search integration, enhancing search stability and performance when operating within rate limits. 
+- **🧾 Citations Relevance Fix**: Adjusted the relevance percentage calculation for citations, so that Open WebUI properly reflects the accuracy of a retrieved document in RAG, ensuring users get clearer insights into sources. +- **🔑 Jina Search API Key Requirement**: Added the option to input an API key for Jina Search, ensuring smooth functionality as keys are now mandatory. + +### Changed + +- **🛠️ Functions Moved to Admin Panel**: As Functions operate as advanced plugins, they are now accessible from the Admin Panel instead of the workspace. +- **🛠️ Manage Ollama Connections**: The "Models" section in Admin Settings has been relocated to Admin Settings > "Connections" > Ollama Connections. You can now manage Ollama instances via a dedicated "Manage Ollama" modal from "Connections", streamlining the setup and configuration of Ollama models. +- **📊 Base Models in Admin Settings**: Admins can now find all base models, both connections and functions, in the "Models" Admin setting. Global model accessibility can be enabled or disabled here. Models are private by default, requiring explicit permission assignment for user access. +- **📌 Sticky Model Selection for New Chats**: The model chosen from a previous chat now persists when creating a new chat. If you click "New Chat" again from the new chat page, it will revert to your default model. +- **🎨 Design Refactoring**: Overall design refinements across the platform have been made, providing a more cohesive and polished user experience. + +### Removed + +- **📂 Model List Reordering**: Temporarily removed and will be reintroduced in upcoming user group settings improvements. +- **⚙️ Default Model Setting**: Removed the ability to set a default model for users; it will be reintroduced with user group settings in the future. 
+ ## [0.3.35] - 2024-10-26 ### Added diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index 463cc86cc..e051d6646 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -362,8 +362,6 @@ async def get_ollama_tags( user.id, type="read", access_control=model_info.access_control ): filtered_models.append(model) - else: - filtered_models.append(model) models["models"] = filtered_models return models @@ -931,9 +929,6 @@ async def generate_chat_completion( del payload["metadata"] model_id = payload["model"] - if ":" not in model_id: - model_id = f"{model_id}:latest" - model_info = Models.get_model_by_id(model_id) if model_info: @@ -963,6 +958,12 @@ async def generate_chat_completion( status_code=403, detail="Model not found", ) + elif not bypass_filter: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" @@ -1051,6 +1052,12 @@ async def generate_openai_chat_completion( status_code=403, detail="Model not found", ) + else: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" @@ -1133,8 +1140,6 @@ async def get_openai_models( user.id, type="read", access_control=model_info.access_control ): filtered_models.append(model) - else: - filtered_models.append(model) models = filtered_models return { diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index ff842a374..42f4388f5 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -313,7 +313,9 @@ async def get_all_models_responses() -> list: prefix_id = api_config.get("prefix_id", None) if prefix_id: - for model in response["data"]: + for model in ( + response if isinstance(response, list) else 
response.get("data", []) + ): model["id"] = f"{prefix_id}.{model['id']}" log.debug(f"get_all_models:responses() {responses}") @@ -424,8 +426,6 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us user.id, type="read", access_control=model_info.access_control ): filtered_models.append(model) - else: - filtered_models.append(model) models["data"] = filtered_models return models @@ -512,6 +512,12 @@ async def generate_chat_completion( status_code=403, detail="Model not found", ) + elif not bypass_filter: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) # Attemp to get urlIdx from the model models = await get_all_models() diff --git a/backend/open_webui/apps/retrieval/loaders/youtube.py b/backend/open_webui/apps/retrieval/loaders/youtube.py new file mode 100644 index 000000000..ad1088be0 --- /dev/null +++ b/backend/open_webui/apps/retrieval/loaders/youtube.py @@ -0,0 +1,98 @@ +from typing import Any, Dict, Generator, List, Optional, Sequence, Union +from urllib.parse import parse_qs, urlparse +from langchain_core.documents import Document + + +ALLOWED_SCHEMES = {"http", "https"} +ALLOWED_NETLOCS = { + "youtu.be", + "m.youtube.com", + "youtube.com", + "www.youtube.com", + "www.youtube-nocookie.com", + "vid.plus", +} + + +def _parse_video_id(url: str) -> Optional[str]: + """Parse a YouTube URL and return the video ID if valid, otherwise None.""" + parsed_url = urlparse(url) + + if parsed_url.scheme not in ALLOWED_SCHEMES: + return None + + if parsed_url.netloc not in ALLOWED_NETLOCS: + return None + + path = parsed_url.path + + if path.endswith("/watch"): + query = parsed_url.query + parsed_query = parse_qs(query) + if "v" in parsed_query: + ids = parsed_query["v"] + video_id = ids if isinstance(ids, str) else ids[0] + else: + return None + else: + path = parsed_url.path.lstrip("/") + video_id = path.split("/")[-1] + + if len(video_id) != 11: # Video IDs are 11 characters long + return 
None + + return video_id + + +class YoutubeLoader: + """Load `YouTube` video transcripts.""" + + def __init__( + self, + video_id: str, + language: Union[str, Sequence[str]] = "en", + ): + """Initialize with YouTube video ID.""" + _video_id = _parse_video_id(video_id) + self.video_id = _video_id if _video_id is not None else video_id + self._metadata = {"source": video_id} + self.language = language + if isinstance(language, str): + self.language = [language] + else: + self.language = language + + def load(self) -> List[Document]: + """Load YouTube transcripts into `Document` objects.""" + try: + from youtube_transcript_api import ( + NoTranscriptFound, + TranscriptsDisabled, + YouTubeTranscriptApi, + ) + except ImportError: + raise ImportError( + 'Could not import "youtube_transcript_api" Python package. ' + "Please install it with `pip install youtube-transcript-api`." + ) + + try: + transcript_list = YouTubeTranscriptApi.list_transcripts(self.video_id) + except Exception as e: + print(e) + return [] + + try: + transcript = transcript_list.find_transcript(self.language) + except NoTranscriptFound: + transcript = transcript_list.find_transcript(["en"]) + + transcript_pieces: List[Dict[str, Any]] = transcript.fetch() + + transcript = " ".join( + map( + lambda transcript_piece: transcript_piece["text"].strip(" "), + transcript_pieces, + ) + ) + return [Document(page_content=transcript, metadata=self._metadata)] diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 8de2a04cb..776fb98de 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -23,6 +23,7 @@ from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT # Document loaders from open_webui.apps.retrieval.loaders.main import Loader +from open_webui.apps.retrieval.loaders.youtube import YoutubeLoader # Web search engines from open_webui.apps.retrieval.web.main import SearchResult @@ -75,6 +76,8 @@ 
from open_webui.config import ( RAG_FILE_MAX_SIZE, RAG_OPENAI_API_BASE_URL, RAG_OPENAI_API_KEY, + RAG_OLLAMA_BASE_URL, + RAG_OLLAMA_API_KEY, RAG_RELEVANCE_THRESHOLD, RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, @@ -118,9 +121,6 @@ from open_webui.utils.misc import ( from open_webui.utils.utils import get_admin_user, get_verified_user from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter -from langchain_community.document_loaders import ( - YoutubeLoader, -) from langchain_core.documents import Document @@ -163,6 +163,9 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY +app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL +app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY + app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE @@ -261,8 +264,16 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, + ( + app.state.config.OPENAI_API_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_BASE_URL + ), + ( + app.state.config.OPENAI_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_API_KEY + ), app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) @@ -312,6 +323,10 @@ async def get_embedding_config(user=Depends(get_admin_user)): "url": app.state.config.OPENAI_API_BASE_URL, "key": app.state.config.OPENAI_API_KEY, }, + "ollama_config": { + "url": app.state.config.OLLAMA_BASE_URL, + "key": app.state.config.OLLAMA_API_KEY, + }, } @@ -328,8 +343,14 @@ class OpenAIConfigForm(BaseModel): key: str +class OllamaConfigForm(BaseModel): + url: str + key: str + + class 
EmbeddingModelUpdateForm(BaseModel): openai_config: Optional[OpenAIConfigForm] = None + ollama_config: Optional[OllamaConfigForm] = None embedding_engine: str embedding_model: str embedding_batch_size: Optional[int] = 1 @@ -350,6 +371,11 @@ async def update_embedding_config( if form_data.openai_config is not None: app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url app.state.config.OPENAI_API_KEY = form_data.openai_config.key + + if form_data.ollama_config is not None: + app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url + app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key + app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL) @@ -358,8 +384,16 @@ async def update_embedding_config( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, + ( + app.state.config.OPENAI_API_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_BASE_URL + ), + ( + app.state.config.OPENAI_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_API_KEY + ), app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) @@ -372,6 +406,10 @@ async def update_embedding_config( "url": app.state.config.OPENAI_API_BASE_URL, "key": app.state.config.OPENAI_API_KEY, }, + "ollama_config": { + "url": app.state.config.OLLAMA_BASE_URL, + "key": app.state.config.OLLAMA_API_KEY, + }, } except Exception as e: log.exception(f"Problem updating embedding model: {e}") @@ -785,8 +823,16 @@ def save_docs_to_vector_db( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, + ( + app.state.config.OPENAI_API_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + 
else app.state.config.OLLAMA_BASE_URL + ), + ( + app.state.config.OPENAI_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_API_KEY + ), app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) @@ -1011,12 +1057,10 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u if not collection_name: collection_name = calculate_sha256_string(form_data.url)[:63] - loader = YoutubeLoader.from_youtube_url( - form_data.url, - add_video_info=False, - language=app.state.config.YOUTUBE_LOADER_LANGUAGE, - translation=app.state.YOUTUBE_LOADER_TRANSLATION, + loader = YoutubeLoader( + form_data.url, language=app.state.config.YOUTUBE_LOADER_LANGUAGE ) + docs = loader.load() content = " ".join([doc.page_content for doc in docs]) log.debug(f"text_content: {content}") @@ -1235,9 +1279,11 @@ def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)): urls = [result.link for result in web_results] loader = get_web_loader( - urls, verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION + urls, + verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) - docs = loader.load() + docs = loader.aload() save_docs_to_vector_db(docs, collection_name, overwrite=True) diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py index 7d92b7350..6d87c98e3 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/apps/retrieval/utils.py @@ -11,11 +11,6 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document - -from open_webui.apps.ollama.main import ( - GenerateEmbedForm, - generate_ollama_batch_embeddings, -) from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message 
@@ -182,35 +177,34 @@ def merge_and_sort_query_results( def query_collection( collection_names: list[str], - query: str, + queries: list[str], embedding_function, k: int, ) -> dict: - results = [] - query_embedding = embedding_function(query) - - for collection_name in collection_names: - if collection_name: - try: - result = query_doc( - collection_name=collection_name, - k=k, - query_embedding=query_embedding, - ) - if result is not None: - results.append(result.model_dump()) - except Exception as e: - log.exception(f"Error when querying the collection: {e}") - else: - pass + for query in queries: + query_embedding = embedding_function(query) + for collection_name in collection_names: + if collection_name: + try: + result = query_doc( + collection_name=collection_name, + k=k, + query_embedding=query_embedding, + ) + if result is not None: + results.append(result.model_dump()) + except Exception as e: + log.exception(f"Error when querying the collection: {e}") + else: + pass return merge_and_sort_query_results(results, k=k) def query_collection_with_hybrid_search( collection_names: list[str], - query: str, + queries: list[str], embedding_function, k: int, reranking_function, @@ -220,15 +214,16 @@ def query_collection_with_hybrid_search( error = False for collection_name in collection_names: try: - result = query_doc_with_hybrid_search( - collection_name=collection_name, - query=query, - embedding_function=embedding_function, - k=k, - reranking_function=reranking_function, - r=r, - ) - results.append(result) + for query in queries: + result = query_doc_with_hybrid_search( + collection_name=collection_name, + query=query, + embedding_function=embedding_function, + k=k, + reranking_function=reranking_function, + r=r, + ) + results.append(result) except Exception as e: log.exception( "Error when querying the collection with " f"hybrid_search: {e}" @@ -285,25 +280,19 @@ def get_embedding_function( embedding_engine, embedding_model, embedding_function, - openai_key, - 
openai_url, + url, + key, embedding_batch_size, ): if embedding_engine == "": return lambda query: embedding_function.encode(query).tolist() elif embedding_engine in ["ollama", "openai"]: - - # Wrapper to run the async generate_embeddings synchronously. - def sync_generate_embeddings(*args, **kwargs): - return asyncio.run(generate_embeddings(*args, **kwargs)) - - # Semantic expectation from the original version (using sync wrapper). - func = lambda query: sync_generate_embeddings( + func = lambda query: generate_embeddings( engine=embedding_engine, model=embedding_model, text=query, - key=openai_key if embedding_engine == "openai" else "", - url=openai_url if embedding_engine == "openai" else "", + url=url, + key=key, ) def generate_multiple(query, func): @@ -320,15 +309,14 @@ def get_embedding_function( def get_rag_context( files, - messages, + queries, embedding_function, k, reranking_function, r, hybrid_search, ): - log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}") - query = get_last_user_message(messages) + log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}") extracted_collections = [] relevant_contexts = [] @@ -370,7 +358,7 @@ def get_rag_context( try: context = query_collection_with_hybrid_search( collection_names=collection_names, - query=query, + queries=queries, embedding_function=embedding_function, k=k, reranking_function=reranking_function, @@ -385,7 +373,7 @@ def get_rag_context( if (not hybrid_search) or (context is None): context = query_collection( collection_names=collection_names, - query=query, + queries=queries, embedding_function=embedding_function, k=k, ) @@ -476,8 +464,8 @@ def get_model_path(model: str, update_model: bool = False): return model -async def generate_openai_batch_embeddings( - model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1" +def generate_openai_batch_embeddings( + model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: 
str = "" ) -> Optional[list[list[float]]]: try: r = requests.post( @@ -499,31 +487,50 @@ async def generate_openai_batch_embeddings( return None -async def generate_embeddings( - engine: str, model: str, text: Union[str, list[str]], **kwargs -): +def generate_ollama_batch_embeddings( + model: str, texts: list[str], url: str, key: str +) -> Optional[list[list[float]]]: + try: + r = requests.post( + f"{url}/api/embed", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {key}", + }, + json={"input": texts, "model": model}, + ) + r.raise_for_status() + data = r.json() + + print(data) + if "embeddings" in data: + return data["embeddings"] + else: + raise "Something went wrong :/" + except Exception as e: + print(e) + return None + + +def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs): + url = kwargs.get("url", "") + key = kwargs.get("key", "") + if engine == "ollama": if isinstance(text, list): - embeddings = await generate_ollama_batch_embeddings( - GenerateEmbedForm(**{"model": model, "input": text}) + embeddings = generate_ollama_batch_embeddings( + **{"model": model, "texts": text, "url": url, "key": key} ) else: - embeddings = await generate_ollama_batch_embeddings( - GenerateEmbedForm(**{"model": model, "input": [text]}) + embeddings = generate_ollama_batch_embeddings( + **{"model": model, "texts": [text], "url": url, "key": key} ) - return ( - embeddings["embeddings"][0] - if isinstance(text, str) - else embeddings["embeddings"] - ) + return embeddings[0] if isinstance(text, str) else embeddings elif engine == "openai": - key = kwargs.get("key", "") - url = kwargs.get("url", "https://api.openai.com/v1") - if isinstance(text, list): - embeddings = await generate_openai_batch_embeddings(model, text, key, url) + embeddings = generate_openai_batch_embeddings(model, text, url, key) else: - embeddings = await generate_openai_batch_embeddings(model, [text], key, url) + embeddings = 
generate_openai_batch_embeddings(model, [text], url, key) return embeddings[0] if isinstance(text, str) else embeddings diff --git a/backend/open_webui/apps/retrieval/vector/dbs/opensearch.py b/backend/open_webui/apps/retrieval/vector/dbs/opensearch.py index 248865479..6234b2837 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/opensearch.py @@ -1,7 +1,7 @@ from opensearchpy import OpenSearch from typing import Optional -from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult +from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( OPENSEARCH_URI, OPENSEARCH_SSL, diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/apps/socket/main.py index fca268a6b..5c284f18d 100644 --- a/backend/open_webui/apps/socket/main.py +++ b/backend/open_webui/apps/socket/main.py @@ -1,3 +1,5 @@ +# TODO: move socket to webui app + import asyncio import socketio import logging diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index ae54ab29a..ce4945b69 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -35,6 +35,7 @@ from open_webui.config import ( ENABLE_LOGIN_FORM, ENABLE_MESSAGE_RATING, ENABLE_SIGNUP, + ENABLE_API_KEY, ENABLE_EVALUATION_ARENA_MODELS, EVALUATION_ARENA_MODELS, DEFAULT_ARENA_MODEL, @@ -98,6 +99,8 @@ app.state.config = AppConfig() app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM +app.state.config.ENABLE_API_KEY = ENABLE_API_KEY + app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER @@ -406,6 +409,7 @@ async def generate_function_chat_completion(form_data, user, models: dict = {}): "name": user.name, "role": user.role, }, + 
"__metadata__": metadata, } extra_params["__tools__"] = get_tools( app, diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/apps/webui/models/chats.py index f6a1e4548..21250add8 100644 --- a/backend/open_webui/apps/webui/models/chats.py +++ b/backend/open_webui/apps/webui/models/chats.py @@ -203,15 +203,22 @@ class ChatTable: def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: try: with get_db() as db: - print("update_shared_chat_by_id") chat = db.get(Chat, chat_id) - print(chat) - chat.title = chat.title - chat.chat = chat.chat - db.commit() - db.refresh(chat) + shared_chat = ( + db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first() + ) - return self.get_chat_by_id(chat.share_id) + if shared_chat is None: + return self.insert_shared_chat_by_chat_id(chat_id) + + shared_chat.title = chat.title + shared_chat.chat = chat.chat + + shared_chat.updated_at = int(time.time()) + db.commit() + db.refresh(shared_chat) + + return ChatModel.model_validate(shared_chat) except Exception: return None diff --git a/backend/open_webui/apps/webui/models/groups.py b/backend/open_webui/apps/webui/models/groups.py index e687374ea..e692198cd 100644 --- a/backend/open_webui/apps/webui/models/groups.py +++ b/backend/open_webui/apps/webui/models/groups.py @@ -11,7 +11,7 @@ from open_webui.apps.webui.models.files import FileMetadataResponse from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text, JSON +from sqlalchemy import BigInteger, Column, String, Text, JSON, func log = logging.getLogger(__name__) @@ -128,7 +128,12 @@ class GroupTable: return [ GroupModel.model_validate(group) for group in db.query(Group) - .filter(Group.user_ids.contains([user_id])) + .filter( + func.json_array_length(Group.user_ids) > 0 + ) # Ensure array exists + .filter( + Group.user_ids.cast(String).like(f'%"{user_id}"%') + ) # String-based check .order_by(Group.updated_at.desc()) .all() ] diff --git 
a/backend/open_webui/apps/webui/models/knowledge.py b/backend/open_webui/apps/webui/models/knowledge.py index 2d0e33f1b..e1a13b3fd 100644 --- a/backend/open_webui/apps/webui/models/knowledge.py +++ b/backend/open_webui/apps/webui/models/knowledge.py @@ -8,6 +8,7 @@ from open_webui.apps.webui.internal.db import Base, get_db from open_webui.env import SRC_LOG_LEVELS from open_webui.apps.webui.models.files import FileMetadataResponse +from open_webui.apps.webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict @@ -79,17 +80,15 @@ class KnowledgeModel(BaseModel): #################### -class KnowledgeResponse(BaseModel): - id: str - name: str - description: str - data: Optional[dict] = None - meta: Optional[dict] = None +class KnowledgeUserModel(KnowledgeModel): + user: Optional[UserResponse] = None - access_control: Optional[dict] = None - created_at: int # timestamp in epoch - updated_at: int # timestamp in epoch +class KnowledgeResponse(KnowledgeModel): + files: Optional[list[FileMetadataResponse | dict]] = None + + +class KnowledgeUserResponse(KnowledgeUserModel): files: Optional[list[FileMetadataResponse | dict]] = None @@ -127,18 +126,26 @@ class KnowledgeTable: except Exception: return None - def get_knowledge_bases(self) -> list[KnowledgeModel]: + def get_knowledge_bases(self) -> list[KnowledgeUserModel]: with get_db() as db: - return [ - KnowledgeModel.model_validate(knowledge) - for knowledge in db.query(Knowledge) - .order_by(Knowledge.updated_at.desc()) - .all() - ] + knowledge_bases = [] + for knowledge in ( + db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all() + ): + user = Users.get_user_by_id(knowledge.user_id) + knowledge_bases.append( + KnowledgeUserModel.model_validate( + { + **KnowledgeModel.model_validate(knowledge).model_dump(), + "user": user.model_dump() if user else None, + } + ) + ) + return knowledge_bases def get_knowledge_bases_by_user_id( self, user_id: str, permission: str = "write" - ) -> 
list[KnowledgeModel]: + ) -> list[KnowledgeUserModel]: knowledge_bases = self.get_knowledge_bases() return [ knowledge_base diff --git a/backend/open_webui/apps/webui/models/models.py b/backend/open_webui/apps/webui/models/models.py index 46591bd95..50581bc73 100644 --- a/backend/open_webui/apps/webui/models/models.py +++ b/backend/open_webui/apps/webui/models/models.py @@ -5,7 +5,7 @@ from typing import Optional from open_webui.apps.webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS -from open_webui.apps.webui.models.groups import Groups +from open_webui.apps.webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict @@ -124,20 +124,12 @@ class ModelModel(BaseModel): #################### -class ModelResponse(BaseModel): - id: str - user_id: str - base_model_id: Optional[str] = None +class ModelUserResponse(ModelModel): + user: Optional[UserResponse] = None - name: str - params: ModelParams - meta: ModelMeta - access_control: Optional[dict] = None - - is_active: bool - updated_at: int # timestamp in epoch - created_at: int # timestamp in epoch +class ModelResponse(ModelModel): + pass class ModelForm(BaseModel): @@ -181,12 +173,20 @@ class ModelsTable: with get_db() as db: return [ModelModel.model_validate(model) for model in db.query(Model).all()] - def get_models(self) -> list[ModelModel]: + def get_models(self) -> list[ModelUserResponse]: with get_db() as db: - return [ - ModelModel.model_validate(model) - for model in db.query(Model).filter(Model.base_model_id != None).all() - ] + models = [] + for model in db.query(Model).filter(Model.base_model_id != None).all(): + user = Users.get_user_by_id(model.user_id) + models.append( + ModelUserResponse.model_validate( + { + **ModelModel.model_validate(model).model_dump(), + "user": user.model_dump() if user else None, + } + ) + ) + return models def get_base_models(self) -> list[ModelModel]: with get_db() as db: @@ -197,8 +197,8 @@ class ModelsTable: 
def get_models_by_user_id( self, user_id: str, permission: str = "write" - ) -> list[ModelModel]: - models = self.get_all_models() + ) -> list[ModelUserResponse]: + models = self.get_models() return [ model for model in models @@ -260,5 +260,15 @@ class ModelsTable: except Exception: return False + def delete_all_models(self) -> bool: + try: + with get_db() as db: + db.query(Model).delete() + db.commit() + + return True + except Exception: + return False + Models = ModelsTable() diff --git a/backend/open_webui/apps/webui/models/prompts.py b/backend/open_webui/apps/webui/models/prompts.py index ea4a229f7..fe9999195 100644 --- a/backend/open_webui/apps/webui/models/prompts.py +++ b/backend/open_webui/apps/webui/models/prompts.py @@ -2,7 +2,7 @@ import time from typing import Optional from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.groups import Groups +from open_webui.apps.webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON @@ -57,6 +57,10 @@ class PromptModel(BaseModel): #################### +class PromptUserResponse(PromptModel): + user: Optional[UserResponse] = None + + class PromptForm(BaseModel): command: str title: str @@ -97,15 +101,26 @@ class PromptsTable: except Exception: return None - def get_prompts(self) -> list[PromptModel]: + def get_prompts(self) -> list[PromptUserResponse]: with get_db() as db: - return [ - PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() - ] + prompts = [] + + for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all(): + user = Users.get_user_by_id(prompt.user_id) + prompts.append( + PromptUserResponse.model_validate( + { + **PromptModel.model_validate(prompt).model_dump(), + "user": user.model_dump() if user else None, + } + ) + ) + + return prompts def get_prompts_by_user_id( self, user_id: str, permission: str = "write" - ) -> list[PromptModel]: + ) -> 
list[PromptUserResponse]: prompts = self.get_prompts() return [ diff --git a/backend/open_webui/apps/webui/models/tools.py b/backend/open_webui/apps/webui/models/tools.py index 63570bee6..b628f4f9f 100644 --- a/backend/open_webui/apps/webui/models/tools.py +++ b/backend/open_webui/apps/webui/models/tools.py @@ -3,7 +3,7 @@ import time from typing import Optional from open_webui.apps.webui.internal.db import Base, JSONField, get_db -from open_webui.apps.webui.models.users import Users +from open_webui.apps.webui.models.users import Users, UserResponse from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON @@ -86,6 +86,10 @@ class ToolResponse(BaseModel): created_at: int # timestamp in epoch +class ToolUserResponse(ToolResponse): + user: Optional[UserResponse] = None + + class ToolForm(BaseModel): id: str name: str @@ -134,13 +138,24 @@ class ToolsTable: except Exception: return None - def get_tools(self) -> list[ToolModel]: + def get_tools(self) -> list[ToolUserResponse]: with get_db() as db: - return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] + tools = [] + for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all(): + user = Users.get_user_by_id(tool.user_id) + tools.append( + ToolUserResponse.model_validate( + { + **ToolModel.model_validate(tool).model_dump(), + "user": user.model_dump() if user else None, + } + ) + ) + return tools def get_tools_by_user_id( self, user_id: str, permission: str = "write" - ) -> list[ToolModel]: + ) -> list[ToolUserResponse]: tools = self.get_tools() return [ diff --git a/backend/open_webui/apps/webui/models/users.py b/backend/open_webui/apps/webui/models/users.py index 328618a67..5bbcc3099 100644 --- a/backend/open_webui/apps/webui/models/users.py +++ b/backend/open_webui/apps/webui/models/users.py @@ -62,6 +62,14 @@ class UserModel(BaseModel): #################### +class UserResponse(BaseModel): + id: str + 
name: str + email: str + role: str + profile_image_url: str + + class UserRoleUpdateForm(BaseModel): id: str role: str diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/apps/webui/routers/auths.py index d3592f03b..63ee5e3b0 100644 --- a/backend/open_webui/apps/webui/routers/auths.py +++ b/backend/open_webui/apps/webui/routers/auths.py @@ -18,9 +18,10 @@ from open_webui.apps.webui.models.auths import ( UserResponse, ) from open_webui.apps.webui.models.users import Users -from open_webui.config import WEBUI_AUTH + from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.env import ( + WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, WEBUI_SESSION_COOKIE_SAME_SITE, @@ -580,6 +581,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)): return { "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS, "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP, + "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY, "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, @@ -590,6 +592,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)): class AdminConfig(BaseModel): SHOW_ADMIN_DETAILS: bool ENABLE_SIGNUP: bool + ENABLE_API_KEY: bool DEFAULT_USER_ROLE: str JWT_EXPIRES_IN: str ENABLE_COMMUNITY_SHARING: bool @@ -602,6 +605,7 @@ async def update_admin_config( ): request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP + request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]: request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE @@ -620,6 +624,7 @@ async def update_admin_config( return { "SHOW_ADMIN_DETAILS": 
request.app.state.config.SHOW_ADMIN_DETAILS, "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP, + "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY, "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, @@ -733,9 +738,16 @@ async def update_ldap_config( # create api key @router.post("/api_key", response_model=ApiKey) -async def create_api_key_(user=Depends(get_current_user)): +async def generate_api_key(request: Request, user=Depends(get_current_user)): + if not request.app.state.config.ENABLE_API_KEY: + raise HTTPException( + status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED, + ) + api_key = create_api_key() success = Users.update_user_api_key_by_id(user.id, api_key) + if success: return { "api_key": api_key, diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/apps/webui/routers/knowledge.py index 966e82960..1b063cda2 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/apps/webui/routers/knowledge.py @@ -8,6 +8,7 @@ from open_webui.apps.webui.models.knowledge import ( Knowledges, KnowledgeForm, KnowledgeResponse, + KnowledgeUserResponse, ) from open_webui.apps.webui.models.files import Files, FileModel from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT @@ -32,7 +33,7 @@ router = APIRouter() ############################ -@router.get("/", response_model=list[KnowledgeResponse]) +@router.get("/", response_model=list[KnowledgeUserResponse]) async def get_knowledge(user=Depends(get_verified_user)): knowledge_bases = [] @@ -42,6 +43,7 @@ async def get_knowledge(user=Depends(get_verified_user)): knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read") # Get files for each knowledge base + knowledge_with_files = [] for knowledge_base in knowledge_bases: files = [] if 
knowledge_base.data: @@ -69,15 +71,17 @@ async def get_knowledge(user=Depends(get_verified_user)): files = Files.get_file_metadatas_by_ids(file_ids) - knowledge_base = KnowledgeResponse( - **knowledge_base.model_dump(), - files=files, + knowledge_with_files.append( + KnowledgeUserResponse( + **knowledge_base.model_dump(), + files=files, + ) ) - return knowledge_bases + return knowledge_with_files -@router.get("/list", response_model=list[KnowledgeResponse]) +@router.get("/list", response_model=list[KnowledgeUserResponse]) async def get_knowledge_list(user=Depends(get_verified_user)): knowledge_bases = [] @@ -87,6 +91,7 @@ async def get_knowledge_list(user=Depends(get_verified_user)): knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "write") # Get files for each knowledge base + knowledge_with_files = [] for knowledge_base in knowledge_bases: files = [] if knowledge_base.data: @@ -114,12 +119,13 @@ async def get_knowledge_list(user=Depends(get_verified_user)): files = Files.get_file_metadatas_by_ids(file_ids) - knowledge_base = KnowledgeResponse( - **knowledge_base.model_dump(), - files=files, + knowledge_with_files.append( + KnowledgeUserResponse( + **knowledge_base.model_dump(), + files=files, + ) ) - - return knowledge_bases + return knowledge_with_files ############################ diff --git a/backend/open_webui/apps/webui/routers/models.py b/backend/open_webui/apps/webui/routers/models.py index 634630622..6a8085385 100644 --- a/backend/open_webui/apps/webui/routers/models.py +++ b/backend/open_webui/apps/webui/routers/models.py @@ -4,6 +4,7 @@ from open_webui.apps.webui.models.models import ( ModelForm, ModelModel, ModelResponse, + ModelUserResponse, Models, ) from open_webui.constants import ERROR_MESSAGES @@ -22,7 +23,7 @@ router = APIRouter() ########################### -@router.get("/", response_model=list[ModelResponse]) +@router.get("/", response_model=list[ModelUserResponse]) async def get_models(id: Optional[str] = None, 
user=Depends(get_verified_user)): if user.role == "admin": return Models.get_models() @@ -82,7 +83,8 @@ async def create_new_model( ########################### -@router.get("/id/{id}", response_model=Optional[ModelResponse]) +# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id +@router.get("/model", response_model=Optional[ModelResponse]) async def get_model_by_id(id: str, user=Depends(get_verified_user)): model = Models.get_model_by_id(id) if model: @@ -104,7 +106,7 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)): ############################ -@router.post("/id/{id}/toggle", response_model=Optional[ModelResponse]) +@router.post("/model/toggle", response_model=Optional[ModelResponse]) async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): model = Models.get_model_by_id(id) if model: @@ -139,7 +141,7 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): ############################ -@router.post("/id/{id}/update", response_model=Optional[ModelModel]) +@router.post("/model/update", response_model=Optional[ModelModel]) async def update_model_by_id( id: str, form_data: ModelForm, @@ -162,7 +164,7 @@ async def update_model_by_id( ############################ -@router.delete("/id/{id}/delete", response_model=bool) +@router.delete("/model/delete", response_model=bool) async def delete_model_by_id(id: str, user=Depends(get_verified_user)): model = Models.get_model_by_id(id) if not model: @@ -179,3 +181,9 @@ async def delete_model_by_id(id: str, user=Depends(get_verified_user)): result = Models.delete_model_by_id(id) return result + + +@router.delete("/delete/all", response_model=bool) +async def delete_all_models(user=Depends(get_admin_user)): + result = Models.delete_all_models() + return result diff --git a/backend/open_webui/apps/webui/routers/prompts.py b/backend/open_webui/apps/webui/routers/prompts.py index e3aab4043..7cacde606 100644 --- 
a/backend/open_webui/apps/webui/routers/prompts.py +++ b/backend/open_webui/apps/webui/routers/prompts.py @@ -1,6 +1,11 @@ from typing import Optional -from open_webui.apps.webui.models.prompts import PromptForm, PromptModel, Prompts +from open_webui.apps.webui.models.prompts import ( + PromptForm, + PromptUserResponse, + PromptModel, + Prompts, +) from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, status, Request from open_webui.utils.utils import get_admin_user, get_verified_user @@ -23,7 +28,7 @@ async def get_prompts(user=Depends(get_verified_user)): return prompts -@router.get("/list", response_model=list[PromptModel]) +@router.get("/list", response_model=list[PromptUserResponse]) async def get_prompt_list(user=Depends(get_verified_user)): if user.role == "admin": prompts = Prompts.get_prompts() diff --git a/backend/open_webui/apps/webui/routers/tools.py b/backend/open_webui/apps/webui/routers/tools.py index fb6292f2f..883c34405 100644 --- a/backend/open_webui/apps/webui/routers/tools.py +++ b/backend/open_webui/apps/webui/routers/tools.py @@ -2,7 +2,13 @@ import os from pathlib import Path from typing import Optional -from open_webui.apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools +from open_webui.apps.webui.models.tools import ( + ToolForm, + ToolModel, + ToolResponse, + ToolUserResponse, + Tools, +) from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports from open_webui.config import CACHE_DIR, DATA_DIR from open_webui.constants import ERROR_MESSAGES @@ -19,7 +25,7 @@ router = APIRouter() ############################ -@router.get("/", response_model=list[ToolResponse]) +@router.get("/", response_model=list[ToolUserResponse]) async def get_tools(user=Depends(get_verified_user)): if user.role == "admin": tools = Tools.get_tools() @@ -33,7 +39,7 @@ async def get_tools(user=Depends(get_verified_user)): ############################ -@router.get("/list", 
response_model=list[ToolResponse]) +@router.get("/list", response_model=list[ToolUserResponse]) async def get_tool_list(user=Depends(get_verified_user)): if user.role == "admin": tools = Tools.get_tools() diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 47834a192..a5adbb0f1 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -265,6 +265,13 @@ class AppConfig: # WEBUI_AUTH (Required for security) #################################### +ENABLE_API_KEY = PersistentConfig( + "ENABLE_API_KEY", + "auth.api_key.enable", + os.environ.get("ENABLE_API_KEY", "True").lower() == "true", +) + + JWT_EXPIRES_IN = PersistentConfig( "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") ) @@ -941,19 +948,49 @@ ENABLE_TAGS_GENERATION = PersistentConfig( os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true", ) -ENABLE_SEARCH_QUERY = PersistentConfig( - "ENABLE_SEARCH_QUERY", - "task.search.enable", - os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true", + +ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig( + "ENABLE_SEARCH_QUERY_GENERATION", + "task.query.search.enable", + os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true", +) + +ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig( + "ENABLE_RETRIEVAL_QUERY_GENERATION", + "task.query.retrieval.enable", + os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true", ) -SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig( - "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", - "task.search.prompt_template", - os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""), +QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "QUERY_GENERATION_PROMPT_TEMPLATE", + "task.query.prompt_template", + os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""), ) +DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task: +Based on the chat history, determine whether a search is necessary, and if so, generate a 
1-3 broad search queries to retrieve comprehensive and updated information. If no search is required, return an empty list. + +### Guidelines: +- Respond exclusively with a JSON object. +- If a search query is needed, return an object like: { "queries": ["query1", "query2"] } where each query is distinct and concise. +- If no search query is necessary, output should be: { "queries": [] } +- Default to suggesting a search query to ensure accurate and updated information, unless it is definitively clear no search is required. +- Be concise, focusing strictly on composing search queries with no additional commentary or text. +- When in doubt, prefer to suggest a search for comprehensiveness. +- Today's date is: {{CURRENT_DATE}} + +### Output: +JSON format: { + "queries": ["query1", "query2"] +} + +### Chat History: + +{{MESSAGES:END:6}} + +""" + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", @@ -1181,6 +1218,19 @@ RAG_OPENAI_API_KEY = PersistentConfig( os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY), ) +RAG_OLLAMA_BASE_URL = PersistentConfig( + "RAG_OLLAMA_BASE_URL", + "rag.ollama.url", + os.getenv("RAG_OLLAMA_BASE_URL", OLLAMA_BASE_URL), +) + +RAG_OLLAMA_API_KEY = PersistentConfig( + "RAG_OLLAMA_API_KEY", + "rag.ollama.key", + os.getenv("RAG_OLLAMA_API_KEY", ""), +) + + ENABLE_RAG_LOCAL_WEB_FETCH = ( os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true" ) diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index d6f33af4a..9c7d6f9e9 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -62,6 +62,7 @@ class ERROR_MESSAGES(str, Enum): NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." 
+ API_KEY_NOT_ALLOWED = "Use of API key is not enabled in the environment." MALICIOUS = "Unusual activities detected, please try again in a few minutes." @@ -75,6 +76,7 @@ class ERROR_MESSAGES(str, Enum): OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found" OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance." + API_KEY_CREATION_NOT_ALLOWED = "API key creation is not allowed in the environment." EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index f639b932c..6a7cbb7eb 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -78,11 +78,13 @@ from open_webui.config import ( ENV, FRONTEND_BUILD_DIR, OAUTH_PROVIDERS, - ENABLE_SEARCH_QUERY, - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, STATIC_DIR, TASK_MODEL, TASK_MODEL_EXTERNAL, + ENABLE_SEARCH_QUERY_GENERATION, + ENABLE_RETRIEVAL_QUERY_GENERATION, + QUERY_GENERATION_PROMPT_TEMPLATE, + DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, TITLE_GENERATION_PROMPT_TEMPLATE, TAGS_GENERATION_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, @@ -122,7 +124,7 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.task import ( moa_response_generation_template, tags_generation_template, - search_query_generation_template, + query_generation_template, emoji_generation_template, title_generation_template, tools_function_calling_generation_template, @@ -206,10 +208,9 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE -app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY -app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( - 
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE -) +app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION +app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION +app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE @@ -492,14 +493,41 @@ async def chat_completion_tools_handler( return body, {"contexts": contexts, "citations": citations} -async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: +async def chat_completion_files_handler( + body: dict, user: UserModel +) -> tuple[dict, dict[str, list]]: contexts = [] citations = [] + try: + queries_response = await generate_queries( + { + "model": body["model"], + "messages": body["messages"], + "type": "retrieval", + }, + user, + ) + queries_response = queries_response["choices"][0]["message"]["content"] + + try: + queries_response = json.loads(queries_response) + except Exception as e: + queries_response = {"queries": []} + + queries = queries_response.get("queries", []) + except Exception as e: + queries = [] + + if len(queries) == 0: + queries = [get_last_user_message(body["messages"])] + + print(f"{queries=}") + if files := body.get("metadata", {}).get("files", None): contexts, citations = get_rag_context( files=files, - messages=body["messages"], + queries=queries, embedding_function=retrieval_app.state.EMBEDDING_FUNCTION, k=retrieval_app.state.config.TOP_K, reranking_function=retrieval_app.state.sentence_transformer_rf, @@ -557,16 +585,34 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): model_info = Models.get_model_by_id(model["id"]) if user.role == "user": - if model_info and not ( - user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control - ) - ): - return JSONResponse( - status_code=status.HTTP_403_FORBIDDEN, - content={"detail": "User does not 
have access to the model"}, - ) + if model.get("arena"): + if not has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + else: + if not model_info: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"detail": "Model not found"}, + ) + elif not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"detail": "User does not have access to the model"}, + ) metadata = { "chat_id": body.pop("chat_id", None), @@ -586,6 +632,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): "name": user.name, "role": user.role, }, + "__metadata__": metadata, } # Initialize data_items to store additional data to be sent to the client @@ -624,7 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): log.exception(e) try: - body, flags = await chat_completion_files_handler(body) + body, flags = await chat_completion_files_handler(body, user) contexts.extend(flags.get("contexts", [])) citations.extend(flags.get("citations", [])) except Exception as e: @@ -808,6 +855,11 @@ class PipelineMiddleware(BaseHTTPMiddleware): status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": "Not authenticated"}, ) + except HTTPException as e: + return JSONResponse( + status_code=e.status_code, + content={"detail": e.detail}, + ) model_list = await get_all_models() models = {model["id"]: model for model in model_list} @@ -893,6 +945,7 @@ async def commit_session_after_request(request: Request, call_next): @app.middleware("http") async def check_url(request: Request, call_next): start_time = int(time.time()) + request.state.enable_api_key = webui_app.state.config.ENABLE_API_KEY response = await call_next(request) process_time = int(time.time()) - start_time 
response.headers["X-Process-Time"] = str(process_time) @@ -1042,17 +1095,17 @@ async def get_all_models(): ) # Process action_ids to get the actions - def get_action_items_from_module(module): + def get_action_items_from_module(function, module): actions = [] if hasattr(module, "actions"): actions = module.actions return [ { - "id": f"{module.id}.{action['id']}", - "name": action.get("name", f"{module.name} ({action['id']})"), - "description": module.meta.description, + "id": f"{function.id}.{action['id']}", + "name": action.get("name", f"{function.name} ({action['id']})"), + "description": function.meta.description, "icon_url": action.get( - "icon_url", module.meta.manifest.get("icon_url", None) + "icon_url", function.meta.manifest.get("icon_url", None) ), } for action in actions @@ -1060,10 +1113,10 @@ async def get_all_models(): else: return [ { - "id": module.id, - "name": module.name, - "description": module.meta.description, - "icon_url": module.meta.manifest.get("icon_url", None), + "id": function.id, + "name": function.name, + "description": function.meta.description, + "icon_url": function.meta.manifest.get("icon_url", None), } ] @@ -1088,7 +1141,9 @@ async def get_all_models(): raise Exception(f"Action not found: {action_id}") function_module = get_function_module_by_id(action_id) - model["actions"].extend(get_action_items_from_module(function_module)) + model["actions"].extend( + get_action_items_from_module(action_function, function_module) + ) return models @@ -1107,14 +1162,23 @@ async def get_models(user=Depends(get_verified_user)): if user.role == "user": filtered_models = [] for model in models: + if model.get("arena"): + if has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + filtered_models.append(model) + continue + model_info = Models.get_model_by_id(model["id"]) if model_info: if user.id == model_info.user_id or has_access( user.id, type="read", 
access_control=model_info.access_control ): filtered_models.append(model) - else: - filtered_models.append(model) models = filtered_models return {"data": models} @@ -1144,19 +1208,38 @@ async def generate_chat_completions( ) model = models[model_id] + # Check if user has access to the model - if user.role == "user": - model_info = Models.get_model_by_id(model_id) - if not ( - user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control - ) - ): - raise HTTPException( - status_code=403, - detail="Model not found", - ) + if not bypass_filter and user.role == "user": + if model.get("arena"): + if not has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + else: + model_info = Models.get_model_by_id(model_id) + if not model_info: + raise HTTPException( + status_code=404, + detail="Model not found", + ) + elif not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) if model["owned_by"] == "arena": model_ids = model.get("info", {}).get("meta", {}).get("model_ids") @@ -1165,9 +1248,7 @@ async def generate_chat_completions( model_ids = [ model["id"] for model in await get_all_models() - if model.get("owned_by") != "arena" - and not model.get("info", {}).get("meta", {}).get("hidden", False) - and model["id"] not in model_ids + if model.get("owned_by") != "arena" and model["id"] not in model_ids ] selected_model_id = None @@ -1178,7 +1259,6 @@ async def generate_chat_completions( model["id"] for model in await get_all_models() if model.get("owned_by") != "arena" - and not model.get("info", {}).get("meta", {}).get("hidden", False) ] selected_model_id = random.choice(model_ids) @@ -1533,8 +1613,9 @@ async def 
get_task_config(user=Depends(get_verified_user)): "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, - "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY, - "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, + "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, } @@ -1545,8 +1626,9 @@ class TaskConfigForm(BaseModel): TITLE_GENERATION_PROMPT_TEMPLATE: str TAGS_GENERATION_PROMPT_TEMPLATE: str ENABLE_TAGS_GENERATION: bool - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str - ENABLE_SEARCH_QUERY: bool + ENABLE_SEARCH_QUERY_GENERATION: bool + ENABLE_RETRIEVAL_QUERY_GENERATION: bool + QUERY_GENERATION_PROMPT_TEMPLATE: str TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str @@ -1561,11 +1643,16 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u form_data.TAGS_GENERATION_PROMPT_TEMPLATE ) app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION - - app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( - form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE + app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ( + form_data.ENABLE_SEARCH_QUERY_GENERATION + ) + app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ( + form_data.ENABLE_RETRIEVAL_QUERY_GENERATION + ) + + app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = ( + form_data.QUERY_GENERATION_PROMPT_TEMPLATE ) - app.state.config.ENABLE_SEARCH_QUERY = form_data.ENABLE_SEARCH_QUERY 
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE ) @@ -1576,8 +1663,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, - "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, - "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY, + "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, } @@ -1753,14 +1841,22 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } return await generate_chat_completions(form_data=payload, user=user) -@app.post("/api/task/query/completions") -async def generate_search_query(form_data: dict, user=Depends(get_verified_user)): - print("generate_search_query") - if not app.state.config.ENABLE_SEARCH_QUERY: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Search query generation is disabled", - ) +@app.post("/api/task/queries/completions") +async def generate_queries(form_data: dict, user=Depends(get_verified_user)): + print("generate_queries") + type = form_data.get("type") + if type == "web_search": + if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Search query generation is disabled", + ) + elif type == "retrieval": + if not app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + 
detail=f"Query generation is disabled", + ) model_list = await get_all_models() models = {model["id"]: model for model in model_list} @@ -1784,36 +1880,19 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) model = models[task_model_id] - if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "": - template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE + if app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE != "": + template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE else: - template = """Given the user's message and interaction history, decide if a web search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt. Today's date is {{CURRENT_DATE}}. 
+ template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE -User Message: -{{prompt:end:4000}} - -Interaction History: -{{MESSAGES:END:6}} - -Search Query:""" - - content = search_query_generation_template( + content = query_generation_template( template, form_data["messages"], {"name": user.name} ) - print("content", content) - payload = { "model": task_model_id, "messages": [{"role": "user", "content": content}], "stream": False, - **( - {"max_tokens": 30} - if models[task_model_id]["owned_by"] == "ollama" - else { - "max_completion_tokens": 30, - } - ), "metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data}, } log.debug(payload) @@ -2354,6 +2433,7 @@ async def get_app_config(request: Request): "auth": WEBUI_AUTH, "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), "enable_ldap": webui_app.state.config.ENABLE_LDAP, + "enable_api_key": webui_app.state.config.ENABLE_API_KEY, "enable_signup": webui_app.state.config.ENABLE_SIGNUP, "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM, **( diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 799cca11a..28b07da37 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -163,7 +163,7 @@ def emoji_generation_template( return template -def search_query_generation_template( +def query_generation_template( template: str, messages: list[dict], user: Optional[dict] = None ) -> str: prompt = get_last_user_message(messages) diff --git a/backend/open_webui/utils/utils.py b/backend/open_webui/utils/utils.py index 1c2205ebf..cde953102 100644 --- a/backend/open_webui/utils/utils.py +++ b/backend/open_webui/utils/utils.py @@ -5,13 +5,11 @@ import jwt from datetime import UTC, datetime, timedelta from typing import Optional, Union, List, Dict - from open_webui.apps.webui.models.users import Users from open_webui.constants import ERROR_MESSAGES from open_webui.env import WEBUI_SECRET_KEY - from fastapi import Depends, 
HTTPException, Request, Response, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from passlib.context import CryptContext @@ -93,10 +91,13 @@ def get_current_user( # auth by api key if token.startswith("sk-"): + if not request.state.enable_api_key: + raise HTTPException( + status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED + ) return get_current_user_by_api_key(token) # auth by jwt token - try: data = decode_token(token) except Exception as e: diff --git a/backend/requirements.txt b/backend/requirements.txt index a5bfae585..258f69e25 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -37,8 +37,8 @@ anthropic google-generativeai==0.7.2 tiktoken -langchain==0.3.5 -langchain-community==0.3.3 +langchain==0.3.7 +langchain-community==0.3.7 langchain-chroma==0.1.4 fake-useragent==1.5.1 @@ -82,12 +82,12 @@ authlib==1.3.2 black==24.8.0 langfuse==2.44.0 -youtube-transcript-api==0.6.2 +youtube-transcript-api==0.6.3 pytube==15.0.0 extract_msg pydub -duckduckgo-search~=6.3.4 +duckduckgo-search~=6.3.5 ## Tests docker~=7.1.0 diff --git a/package-lock.json b/package-lock.json index 90ba9b605..d2114a6ac 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.4.0.dev1", + "version": "0.4.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.4.0.dev1", + "version": "0.4.1", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", @@ -3999,9 +3999,10 @@ "integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==" }, "node_modules/cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": 
"https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "license": "MIT", "dependencies": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", diff --git a/package.json b/package.json index b4164a67b..47e232773 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.4.0.dev1", + "version": "0.4.1", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", diff --git a/pyproject.toml b/pyproject.toml index fa16381f2..9a1c2bb03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "fastapi==0.111.0", "uvicorn[standard]==0.30.6", "pydantic==2.9.2", - "python-multipart==0.0.9", + "python-multipart==0.0.17", "Flask==3.0.3", "Flask-Cors==5.0.0", @@ -28,12 +28,13 @@ dependencies = [ "peewee==3.17.6", "peewee-migrate==1.12.2", "psycopg2-binary==2.9.9", + "pgvector==0.3.5", "PyMySQL==1.1.1", "bcrypt==4.2.0", "pymongo", "redis", - "boto3==1.35.0", + "boto3==1.35.53", "argon2-cffi==23.1.0", "APScheduler==3.10.4", @@ -43,13 +44,14 @@ dependencies = [ "google-generativeai==0.7.2", "tiktoken", - "langchain==0.3.0", - "langchain-community==0.2.12", + "langchain==0.3.7", + "langchain-community==0.3.7", "langchain-chroma==0.1.4", "fake-useragent==1.5.1", - "chromadb==0.5.9", - "pymilvus==2.4.7", + "chromadb==0.5.15", + "pymilvus==2.4.9", + "qdrant-client~=1.12.0", "opensearch-py==2.7.1", "sentence-transformers==3.2.0", @@ -86,18 +88,20 @@ dependencies = [ "black==24.8.0", "langfuse==2.44.0", - "youtube-transcript-api==0.6.2", + "youtube-transcript-api==0.6.3", "pytube==15.0.0", "extract_msg", "pydub", - "duckduckgo-search~=6.2.13", + "duckduckgo-search~=6.3.5", "docker~=7.1.0", "pytest~=8.3.2", "pytest-docker~=3.1.1", - "googleapis-common-protos==1.63.2" + "googleapis-common-protos==1.63.2", + + "ldap3==2.9.1" ] readme = "README.md" requires-python = ">= 3.11, < 
3.12.0a1" diff --git a/src/lib/apis/index.ts index d22923670..699980a5e 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -348,15 +348,16 @@ export const generateEmoji = async ( return null; }; -export const generateSearchQuery = async ( +export const generateQueries = async ( token: string = '', model: string, messages: object[], - prompt: string + prompt: string, + type: string = 'web_search' ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/task/queries/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -366,7 +367,8 @@ body: JSON.stringify({ model: model, messages: messages, - prompt: prompt + prompt: prompt, + type: type }) }) .then(async (res) => { @@ -385,7 +387,39 @@ throw error; } - return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt; + try { + // Step 1: Safely extract the response string + const response = res?.choices[0]?.message?.content ?? ''; + + // Step 2: Attempt to fix common JSON format issues like single quotes + const sanitizedResponse = response.replace(/['‘’`]/g, '"'); // Convert single quotes to double quotes for valid JSON + + // Step 3: Find the relevant JSON block within the response + const jsonStartIndex = sanitizedResponse.indexOf('{'); + const jsonEndIndex = sanitizedResponse.lastIndexOf('}'); + + // Step 4: Check if we found a valid JSON block (with both `{` and `}`) + if (jsonStartIndex !== -1 && jsonEndIndex !== -1) { + const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1); + + // Step 5: Parse the JSON block + const parsed = JSON.parse(jsonResponse); + + // Step 6: If there's a "queries" key, return the queries array; otherwise, return an empty array + if (parsed && parsed.queries) { + return Array.isArray(parsed.queries) ? 
parsed.queries : []; + } else { + return []; + } + } + + // If no valid JSON block found, return an empty array + return []; + } catch (e) { + // Catch and safely return empty array on any parsing errors + console.error('Failed to parse response: ', e); + return []; + } }; export const generateMoACompletion = async ( diff --git a/src/lib/apis/models/index.ts b/src/lib/apis/models/index.ts index 4aaf651c6..5880874bb 100644 --- a/src/lib/apis/models/index.ts +++ b/src/lib/apis/models/index.ts @@ -3,7 +3,7 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; export const getModels = async (token: string = '') => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/models`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/`, { method: 'GET', headers: { Accept: 'application/json', @@ -97,7 +97,7 @@ export const getModelById = async (token: string, id: string) => { const searchParams = new URLSearchParams(); searchParams.append('id', id); - const res = await fetch(`${WEBUI_API_BASE_URL}/models/id/${id}`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/model?${searchParams.toString()}`, { method: 'GET', headers: { Accept: 'application/json', @@ -132,7 +132,7 @@ export const toggleModelById = async (token: string, id: string) => { const searchParams = new URLSearchParams(); searchParams.append('id', id); - const res = await fetch(`${WEBUI_API_BASE_URL}/models/id/${id}/toggle`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/model/toggle?${searchParams.toString()}`, { method: 'POST', headers: { Accept: 'application/json', @@ -167,7 +167,7 @@ export const updateModelById = async (token: string, id: string, model: object) const searchParams = new URLSearchParams(); searchParams.append('id', id); - const res = await fetch(`${WEBUI_API_BASE_URL}/models/id/${id}/update`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/model/update?${searchParams.toString()}`, { method: 'POST', headers: { Accept: 'application/json', @@ 
-203,7 +203,39 @@ export const deleteModelById = async (token: string, id: string) => { const searchParams = new URLSearchParams(); searchParams.append('id', id); - const res = await fetch(`${WEBUI_API_BASE_URL}/models/id/${id}/delete`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/model/delete?${searchParams.toString()}`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteAllModels = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/delete/all`, { method: 'DELETE', headers: { Accept: 'application/json', diff --git a/src/lib/components/ChangelogModal.svelte b/src/lib/components/ChangelogModal.svelte index 9002b9100..b395ddcbd 100644 --- a/src/lib/components/ChangelogModal.svelte +++ b/src/lib/components/ChangelogModal.svelte @@ -22,7 +22,7 @@ }); - +
@@ -59,7 +59,7 @@
-
+
{#if changelog} {#each Object.keys(changelog) as version} @@ -111,7 +111,7 @@ await updateUserSettings(localStorage.token, { ui: $settings }); show = false; }} - class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg" + class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full" > {$i18n.t("Okay, Let's Go!")} diff --git a/src/lib/components/admin/Evaluations.svelte b/src/lib/components/admin/Evaluations.svelte index f76430f26..a5532ae2f 100644 --- a/src/lib/components/admin/Evaluations.svelte +++ b/src/lib/components/admin/Evaluations.svelte @@ -31,7 +31,7 @@ {#if loaded} -
+
-
+
{$i18n.t('Made by OpenWebUI Community')}
diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index c9c198c41..f0886ea5c 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -38,7 +38,7 @@ }); -
+
-
- {$i18n.t('Manage Ollama')} +
+
+ {$i18n.t('Manage Ollama')} +
+ +
+ + + +