diff --git a/CHANGELOG.md b/CHANGELOG.md index a61d81f46..13712232f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,25 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.5.15] - 2025-02-20 + +### Added + +- **📄 Full Context Mode for Local Document Search (RAG)**: Toggle full context mode from Admin Settings > Documents to inject entire document content into context, improving accuracy for models with large context windows—ideal for deep context understanding. +- **🌍 Smarter Web Search with Agentic Workflows**: Web searches now intelligently gather and refine multiple relevant terms, similar to RAG handling, delivering significantly better search results for more accurate information retrieval. +- **🔎 Experimental Playwright Support for Web Loader**: Web content retrieval is taken to the next level with Playwright-powered scraping for enhanced accuracy in extracted web data. +- **☁️ Experimental Azure Storage Provider**: Early-stage support for Azure Storage allows more cloud storage flexibility directly within Open WebUI. +- **📊 Improved Jupyter Code Execution with Plots**: Interactive coding now properly displays inline plots, making data visualization more seamless inside chat interactions. +- **⏳ Adjustable Execution Timeout for Jupyter Interpreter**: Customize execution timeout (default: 60s) for Jupyter-based code execution, allowing longer or more constrained execution based on your needs. +- **▶️ "Running..." Indicator for Jupyter Code Execution**: A visual indicator now appears while code execution is in progress, providing real-time status updates on ongoing computations. +- **⚙️ General Backend & Frontend Stability Enhancements**: Extensive refactoring improves reliability, performance, and overall user experience for a more seamless Open WebUI. +- **🌍 Translation Updates**: Various international translation refinements ensure better localization and a more natural user interface experience. + +### Fixed + +- **📱 Mobile Hover Issue Resolved**: Users can now edit responses smoothly on mobile without interference, fixing a longstanding hover issue. +- **🔄 Temporary Chat Message Duplication Fixed**: Eliminated buggy behavior where messages were being unnecessarily repeated in temporary chat mode, ensuring a smooth and consistent conversation flow. 
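For reference, the adjustable Jupyter timeout announced above resolves in two layers, as the config.py hunk further down shows. A minimal sketch of the environment fallback, using only the two variables introduced in this release:

```python
# Sketch of the fallback chain used by the new Jupyter timeout settings:
# CODE_INTERPRETER_JUPYTER_TIMEOUT falls back to CODE_EXECUTION_JUPYTER_TIMEOUT,
# and both default to 60 seconds when neither variable is set.
import os

execution_timeout = int(os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60"))
interpreter_timeout = int(
    os.environ.get(
        "CODE_INTERPRETER_JUPYTER_TIMEOUT",
        os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60"),
    )
)
```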
+ ## [0.5.14] - 2025-02-17 ### Fixed diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 3e08dbb72..325ba486d 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -684,6 +684,10 @@ GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get( "GOOGLE_APPLICATION_CREDENTIALS_JSON", None ) +AZURE_STORAGE_ENDPOINT = os.environ.get("AZURE_STORAGE_ENDPOINT", None) +AZURE_STORAGE_CONTAINER_NAME = os.environ.get("AZURE_STORAGE_CONTAINER_NAME", None) +AZURE_STORAGE_KEY = os.environ.get("AZURE_STORAGE_KEY", None) + #################################### # File Upload DIR #################################### @@ -783,6 +787,9 @@ ENABLE_OPENAI_API = PersistentConfig( OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") +GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "") +GEMINI_API_BASE_URL = os.environ.get("GEMINI_API_BASE_URL", "") + if OPENAI_API_BASE_URL == "": OPENAI_API_BASE_URL = "https://api.openai.com/v1" @@ -1395,6 +1402,11 @@ CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = PersistentConfig( os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""), ) +CODE_EXECUTION_JUPYTER_TIMEOUT = PersistentConfig( + "CODE_EXECUTION_JUPYTER_TIMEOUT", + "code_execution.jupyter.timeout", + int(os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60")), +) ENABLE_CODE_INTERPRETER = PersistentConfig( "ENABLE_CODE_INTERPRETER", @@ -1450,6 +1462,17 @@ CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig( ), ) +CODE_INTERPRETER_JUPYTER_TIMEOUT = PersistentConfig( + "CODE_INTERPRETER_JUPYTER_TIMEOUT", + "code_interpreter.jupyter.timeout", + int( + os.environ.get( + "CODE_INTERPRETER_JUPYTER_TIMEOUT", + os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60"), + ) + ), +) + DEFAULT_CODE_INTERPRETER_PROMPT = """ #### Tools Available @@ -1571,6 +1594,12 @@ ENABLE_RAG_HYBRID_SEARCH = PersistentConfig( os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true", ) +RAG_FULL_CONTEXT = PersistentConfig( + "RAG_FULL_CONTEXT", + "rag.full_context", + os.getenv("RAG_FULL_CONTEXT", "False").lower() == "true", +) + RAG_FILE_MAX_COUNT = PersistentConfig( "RAG_FILE_MAX_COUNT", "rag.file.max_count", @@ -1919,12 +1948,36 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig( int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")), ) +RAG_WEB_LOADER_ENGINE = PersistentConfig( + "RAG_WEB_LOADER_ENGINE", + "rag.web.loader.engine", + os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web"), +) + RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig( "RAG_WEB_SEARCH_TRUST_ENV", "rag.web.search.trust_env", os.getenv("RAG_WEB_SEARCH_TRUST_ENV", False), ) +PLAYWRIGHT_WS_URI = PersistentConfig( + "PLAYWRIGHT_WS_URI", + "rag.web.loader.engine.playwright.ws.uri", + os.environ.get("PLAYWRIGHT_WS_URI", None), +) + +FIRECRAWL_API_KEY = PersistentConfig( + "FIRECRAWL_API_KEY", + "firecrawl.api_key", + os.environ.get("FIRECRAWL_API_KEY", ""), +) + +FIRECRAWL_API_BASE_URL = PersistentConfig( + "FIRECRAWL_API_BASE_URL", + "firecrawl.api_url", + os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"), +) + #################################### # Images #################################### @@ -2135,6 +2188,17 @@ IMAGES_OPENAI_API_KEY = PersistentConfig( os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY), ) +IMAGES_GEMINI_API_BASE_URL = PersistentConfig( + "IMAGES_GEMINI_API_BASE_URL", + "image_generation.gemini.api_base_url", + os.getenv("IMAGES_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL), +) +IMAGES_GEMINI_API_KEY = PersistentConfig( + 
"IMAGES_GEMINI_API_KEY", + "image_generation.gemini.api_key", + os.getenv("IMAGES_GEMINI_API_KEY", GEMINI_API_KEY), +) + IMAGE_SIZE = PersistentConfig( "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512") ) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index dd0c2bf9f..346d28d6c 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -106,6 +106,7 @@ from open_webui.config import ( CODE_EXECUTION_JUPYTER_AUTH, CODE_EXECUTION_JUPYTER_AUTH_TOKEN, CODE_EXECUTION_JUPYTER_AUTH_PASSWORD, + CODE_EXECUTION_JUPYTER_TIMEOUT, ENABLE_CODE_INTERPRETER, CODE_INTERPRETER_ENGINE, CODE_INTERPRETER_PROMPT_TEMPLATE, @@ -113,6 +114,7 @@ from open_webui.config import ( CODE_INTERPRETER_JUPYTER_AUTH, CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, + CODE_INTERPRETER_JUPYTER_TIMEOUT, # Image AUTOMATIC1111_API_AUTH, AUTOMATIC1111_BASE_URL, @@ -131,6 +133,8 @@ from open_webui.config import ( IMAGE_STEPS, IMAGES_OPENAI_API_BASE_URL, IMAGES_OPENAI_API_KEY, + IMAGES_GEMINI_API_BASE_URL, + IMAGES_GEMINI_API_KEY, # Audio AUDIO_STT_ENGINE, AUDIO_STT_MODEL, @@ -145,6 +149,10 @@ from open_webui.config import ( AUDIO_TTS_VOICE, AUDIO_TTS_AZURE_SPEECH_REGION, AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, + PLAYWRIGHT_WS_URI, + FIRECRAWL_API_BASE_URL, + FIRECRAWL_API_KEY, + RAG_WEB_LOADER_ENGINE, WHISPER_MODEL, DEEPGRAM_API_KEY, WHISPER_MODEL_AUTO_UPDATE, @@ -152,6 +160,7 @@ from open_webui.config import ( # Retrieval RAG_TEMPLATE, DEFAULT_RAG_TEMPLATE, + RAG_FULL_CONTEXT, RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, @@ -515,6 +524,8 @@ app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT + +app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION @@ -576,7 +587,11 @@ app.state.config.EXA_API_KEY = EXA_API_KEY app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS +app.state.config.RAG_WEB_LOADER_ENGINE = RAG_WEB_LOADER_ENGINE app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV +app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI +app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL +app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY app.state.EMBEDDING_FUNCTION = None app.state.ef = None @@ -631,6 +646,7 @@ app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = CODE_EXECUTION_JUPYTER_AUTH app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = ( CODE_EXECUTION_JUPYTER_AUTH_PASSWORD ) +app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = CODE_EXECUTION_JUPYTER_TIMEOUT app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE @@ -644,6 +660,7 @@ app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = ( app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = ( CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD ) +app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = CODE_INTERPRETER_JUPYTER_TIMEOUT ######################################## # @@ -658,6 +675,9 @@ app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL 
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY +app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL +app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY + app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL @@ -967,7 +987,7 @@ async def chat_completion( "files": form_data.get("files", None), "features": form_data.get("features", None), "variables": form_data.get("variables", None), - "model": model_info, + "model": model_info.model_dump() if model_info else model, "direct": model_item.get("direct", False), **( {"function_calling": "native"} diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 437183369..5f181fba0 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -84,6 +84,19 @@ def query_doc( raise e +def get_doc(collection_name: str, user: UserModel = None): + try: + result = VECTOR_DB_CLIENT.get(collection_name=collection_name) + + if result: + log.info(f"get_doc:result {result.ids} {result.metadatas}") + + return result + except Exception as e: + log.exception(e) + raise e + + def query_doc_with_hybrid_search( collection_name: str, query: str, @@ -137,6 +150,27 @@ def query_doc_with_hybrid_search( raise e +def merge_get_results(get_results: list[dict]) -> dict: + # Initialize lists to store combined data + combined_documents = [] + combined_metadatas = [] + combined_ids = [] + + for data in get_results: + combined_documents.extend(data["documents"][0]) + combined_metadatas.extend(data["metadatas"][0]) + combined_ids.extend(data["ids"][0]) + + # Create the output dictionary + result = { + "documents": [combined_documents], + "metadatas": [combined_metadatas], + "ids": [combined_ids], + } + + return result + + def merge_and_sort_query_results( query_results: list[dict], k: int, reverse: bool = False ) -> list[dict]: @@ -144,31 +178,45 @@ combined_distances = [] combined_documents = [] combined_metadatas = [] + combined_ids = [] for data in query_results: combined_distances.extend(data["distances"][0]) combined_documents.extend(data["documents"][0]) combined_metadatas.extend(data["metadatas"][0]) + # De-duplicate on (chunk_id, file_id), since chunk ids may be ordinals that repeat across files + combined_ids.extend( + [ + f"{id}-{meta['file_id']}" + for id, meta in zip(data["ids"][0], data["metadatas"][0]) + ] + ) - # Create a list of tuples (distance, document, metadata) - combined = list(zip(combined_distances, combined_documents, combined_metadatas)) + # Create a list of tuples (distance, document, metadata, ids) + combined = list( + zip(combined_distances, combined_documents, combined_metadatas, combined_ids) + ) # Sort the list based on distances combined.sort(key=lambda x: x[0], reverse=reverse) - # We don't have anything :-( - if not combined: - sorted_distances = [] - sorted_documents = [] - sorted_metadatas = [] - else: + sorted_distances = [] + sorted_documents = [] + sorted_metadatas = [] + # Nothing to do if we don't have anything :-( + if combined: # Unzip the sorted list - sorted_distances, sorted_documents, sorted_metadatas = zip(*combined) - + all_distances, all_documents, all_metadatas, all_ids = zip(*combined) + seen_ids = set() # Slicing the lists to include only k elements - sorted_distances = list(sorted_distances)[:k] - sorted_documents = list(sorted_documents)[:k] - sorted_metadatas = list(sorted_metadatas)[:k] + for index, id in enumerate(all_ids): + if
id not in seen_ids: + sorted_distances.append(all_distances[index]) + sorted_documents.append(all_documents[index]) + sorted_metadatas.append(all_metadatas[index]) + seen_ids.add(id) + if len(sorted_distances) >= k: + break # Create the output dictionary result = { @@ -180,6 +228,23 @@ def merge_and_sort_query_results( return result +def get_all_items_from_collections(collection_names: list[str]) -> dict: + results = [] + + for collection_name in collection_names: + if collection_name: + try: + result = get_doc(collection_name=collection_name) + if result is not None: + results.append(result.model_dump()) + except Exception as e: + log.exception(f"Error when querying the collection: {e}") + else: + pass + + return merge_get_results(results) + + def query_collection( collection_names: list[str], queries: list[str], @@ -297,8 +362,11 @@ def get_sources_from_files( reranking_function, r, hybrid_search, + full_context=False, ): - log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}") + log.debug( + f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}" + ) extracted_collections = [] relevant_contexts = [] @@ -336,36 +404,43 @@ def get_sources_from_files( log.debug(f"skipping {file} as it has already been extracted") continue - try: - context = None - if file.get("type") == "text": - context = file["content"] - else: - if hybrid_search: - try: - context = query_collection_with_hybrid_search( + if full_context: + try: + context = get_all_items_from_collections(collection_names) + except Exception as e: + log.exception(e) + + else: + try: + context = None + if file.get("type") == "text": + context = file["content"] + else: + if hybrid_search: + try: + context = query_collection_with_hybrid_search( + collection_names=collection_names, + queries=queries, + embedding_function=embedding_function, + k=k, + reranking_function=reranking_function, + r=r, + ) + except Exception as e: + log.debug( + "Error when using hybrid search, using" + " non hybrid search as fallback." + ) + + if (not hybrid_search) or (context is None): + context = query_collection( collection_names=collection_names, queries=queries, embedding_function=embedding_function, k=k, - reranking_function=reranking_function, - r=r, ) - except Exception as e: - log.debug( - "Error when using hybrid search, using" - " non hybrid search as fallback." 
- ) - - if (not hybrid_search) or (context is None): - context = query_collection( - collection_names=collection_names, - queries=queries, - embedding_function=embedding_function, - k=k, - ) - except Exception as e: - log.exception(e) + except Exception as e: + log.exception(e) extracted_collections.extend(collection_names) diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index 145f1adbc..fd94a1a32 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -1,22 +1,38 @@ -import socket -import aiohttp import asyncio -import urllib.parse -import validators -from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union - - -from langchain_community.document_loaders import ( - WebBaseLoader, -) -from langchain_core.documents import Document - - -from open_webui.constants import ERROR_MESSAGES -from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH -from open_webui.env import SRC_LOG_LEVELS - import logging +import socket +import ssl +import time +import urllib.parse +import urllib.request +from collections import defaultdict +from datetime import datetime, timedelta +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Sequence, + Union, + Literal, +) +import aiohttp +import certifi +import validators +from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader +from langchain_community.document_loaders.firecrawl import FireCrawlLoader +from langchain_community.document_loaders.base import BaseLoader +from langchain_core.documents import Document +from open_webui.constants import ERROR_MESSAGES +from open_webui.config import ( + ENABLE_RAG_LOCAL_WEB_FETCH, + PLAYWRIGHT_WS_URI, + RAG_WEB_LOADER_ENGINE, + FIRECRAWL_API_BASE_URL, + FIRECRAWL_API_KEY, +) +from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -68,6 +84,314 @@ def resolve_hostname(hostname): return ipv4_addresses, ipv6_addresses +def extract_metadata(soup, url): + metadata = {"source": url} + if title := soup.find("title"): + metadata["title"] = title.get_text() + if description := soup.find("meta", attrs={"name": "description"}): + metadata["description"] = description.get("content", "No description found.") + if html := soup.find("html"): + metadata["language"] = html.get("lang", "No language found.") + return metadata + + +def verify_ssl_cert(url: str) -> bool: + """Verify SSL certificate for the given URL.""" + if not url.startswith("https://"): + return True + + try: + hostname = url.split("://")[-1].split("/")[0] + context = ssl.create_default_context(cafile=certifi.where()) + with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s: + s.connect((hostname, 443)) + return True + except ssl.SSLError: + return False + except Exception as e: + log.warning(f"SSL verification failed for {url}: {str(e)}") + return False + + +class SafeFireCrawlLoader(BaseLoader): + def __init__( + self, + web_paths, + verify_ssl: bool = True, + trust_env: bool = False, + requests_per_second: Optional[float] = None, + continue_on_failure: bool = True, + api_key: Optional[str] = None, + api_url: Optional[str] = None, + mode: Literal["crawl", "scrape", "map"] = "crawl", + proxy: Optional[Dict[str, str]] = None, + params: Optional[Dict] = None, + ): + """Document loader for FireCrawl operations with safety checks. + + Processes each URL with a FireCrawlLoader instance, adding SSL + verification and request rate limiting around the underlying loader. + Args: + web_paths: List of URLs/paths to process. + verify_ssl: If True, verify SSL certificates. + trust_env: If True, use proxy settings from environment variables. + requests_per_second: Number of requests per second to limit to. + continue_on_failure (bool): If True, continue loading other URLs on failure. + api_key: API key for FireCrawl service. Defaults to None + (uses FIRE_CRAWL_API_KEY environment variable if not provided). + api_url: Base URL for FireCrawl API. Defaults to official API endpoint. + mode: Operation mode selection: + - 'crawl': Website crawling mode (default) + - 'scrape': Direct page scraping + - 'map': Site map generation + proxy: Proxy override settings for the FireCrawl API. + params: The parameters to pass to the Firecrawl API. + Examples include crawlerOptions. + For more details, visit: https://github.com/mendableai/firecrawl-py + """ + proxy_server = proxy.get("server") if proxy else None + if trust_env and not proxy_server: + env_proxies = urllib.request.getproxies() + env_proxy_server = env_proxies.get("https") or env_proxies.get("http") + if env_proxy_server: + if proxy: + proxy["server"] = env_proxy_server + else: + proxy = {"server": env_proxy_server} + self.web_paths = web_paths + self.verify_ssl = verify_ssl + self.requests_per_second = requests_per_second + self.last_request_time = None + self.trust_env = trust_env + self.continue_on_failure = continue_on_failure + self.api_key = api_key + self.api_url = api_url + self.mode = mode + self.params = params + + def lazy_load(self) -> Iterator[Document]: + """Load documents using FireCrawl.""" + for url in self.web_paths: + try: + self._safe_process_url_sync(url) + loader = FireCrawlLoader( + url=url, + api_key=self.api_key, + api_url=self.api_url, + mode=self.mode, + params=self.params, + ) + yield from loader.lazy_load() + except Exception as e: + if self.continue_on_failure: + log.exception("Error loading %s", url) + continue + raise e + + async def alazy_load(self): + """Async version of lazy_load.""" + for url in self.web_paths: + try: + await self._safe_process_url(url) + loader = FireCrawlLoader( + url=url, + api_key=self.api_key, + api_url=self.api_url, + mode=self.mode, + params=self.params, + ) + async for document in loader.alazy_load(): + yield document + except Exception as e: + if self.continue_on_failure: + log.exception("Error loading %s", url) + continue + raise e + + def _verify_ssl_cert(self, url: str) -> bool: + return verify_ssl_cert(url) + + async def _wait_for_rate_limit(self): + """Wait to respect the rate limit if specified.""" + if self.requests_per_second and self.last_request_time: + min_interval = timedelta(seconds=1.0 / self.requests_per_second) + time_since_last = datetime.now() - self.last_request_time + if time_since_last < min_interval: + await asyncio.sleep((min_interval - time_since_last).total_seconds()) + self.last_request_time = datetime.now() + + def _sync_wait_for_rate_limit(self): + """Synchronous version of rate limit wait.""" + if self.requests_per_second and self.last_request_time: + min_interval = timedelta(seconds=1.0 / self.requests_per_second) + time_since_last = datetime.now() - self.last_request_time + if time_since_last < min_interval: + time.sleep((min_interval - time_since_last).total_seconds()) + self.last_request_time = datetime.now() + + async def _safe_process_url(self, url: str) -> bool: + """Perform
safety checks before processing a URL.""" + if self.verify_ssl and not self._verify_ssl_cert(url): + raise ValueError(f"SSL certificate verification failed for {url}") + await self._wait_for_rate_limit() + return True + + def _safe_process_url_sync(self, url: str) -> bool: + """Synchronous version of safety checks.""" + if self.verify_ssl and not self._verify_ssl_cert(url): + raise ValueError(f"SSL certificate verification failed for {url}") + self._sync_wait_for_rate_limit() + return True + + +class SafePlaywrightURLLoader(PlaywrightURLLoader): + """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection. + + Attributes: + web_paths (List[str]): List of URLs to load. + verify_ssl (bool): If True, verify SSL certificates. + trust_env (bool): If True, use proxy settings from environment variables. + requests_per_second (Optional[float]): Number of requests per second to limit to. + continue_on_failure (bool): If True, continue loading other URLs on failure. + headless (bool): If True, the browser will run in headless mode. + proxy (dict): Proxy override settings for the Playwright session. + playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection. + """ + + def __init__( + self, + web_paths: List[str], + verify_ssl: bool = True, + trust_env: bool = False, + requests_per_second: Optional[float] = None, + continue_on_failure: bool = True, + headless: bool = True, + remove_selectors: Optional[List[str]] = None, + proxy: Optional[Dict[str, str]] = None, + playwright_ws_url: Optional[str] = None, + ): + """Initialize with additional safety parameters and remote browser support.""" + + proxy_server = proxy.get("server") if proxy else None + if trust_env and not proxy_server: + env_proxies = urllib.request.getproxies() + env_proxy_server = env_proxies.get("https") or env_proxies.get("http") + if env_proxy_server: + if proxy: + proxy["server"] = env_proxy_server + else: + proxy = {"server": env_proxy_server} + + # We'll set headless to False if using playwright_ws_url since it's handled by the remote browser + super().__init__( + urls=web_paths, + continue_on_failure=continue_on_failure, + headless=headless if playwright_ws_url is None else False, + remove_selectors=remove_selectors, + proxy=proxy, + ) + self.verify_ssl = verify_ssl + self.requests_per_second = requests_per_second + self.last_request_time = None + self.playwright_ws_url = playwright_ws_url + self.trust_env = trust_env + + def lazy_load(self) -> Iterator[Document]: + """Safely load URLs synchronously with support for remote browser.""" + from playwright.sync_api import sync_playwright + + with sync_playwright() as p: + # Use remote browser if ws_endpoint is provided, otherwise use local browser + if self.playwright_ws_url: + browser = p.chromium.connect(self.playwright_ws_url) + else: + browser = p.chromium.launch(headless=self.headless, proxy=self.proxy) + + for url in self.urls: + try: + self._safe_process_url_sync(url) + page = browser.new_page() + response = page.goto(url) + if response is None: + raise ValueError(f"page.goto() returned None for url {url}") + + text = self.evaluator.evaluate(page, browser, response) + metadata = {"source": url} + yield Document(page_content=text, metadata=metadata) + except Exception as e: + if self.continue_on_failure: + log.exception("Error loading %s", url) + continue + raise e + browser.close() + + async def alazy_load(self) -> AsyncIterator[Document]: + """Safely load URLs asynchronously with support for remote browser.""" + from playwright.async_api import async_playwright + + async with async_playwright() as p: + # Use remote browser if ws_endpoint is provided, otherwise use local browser + if self.playwright_ws_url: + browser = await p.chromium.connect(self.playwright_ws_url) + else: + browser = await p.chromium.launch( + headless=self.headless, proxy=self.proxy + ) + + for url in self.urls: + try: + await self._safe_process_url(url) + page = await browser.new_page() + response = await page.goto(url) + if response is None: + raise ValueError(f"page.goto() returned None for url {url}") + + text = await self.evaluator.evaluate_async(page, browser, response) + metadata = {"source": url} + yield Document(page_content=text, metadata=metadata) + except Exception as e: + if self.continue_on_failure: + log.exception("Error loading %s", url) + continue + raise e + await browser.close() + + def _verify_ssl_cert(self, url: str) -> bool: + return verify_ssl_cert(url) + + async def _wait_for_rate_limit(self): + """Wait to respect the rate limit if specified.""" + if self.requests_per_second and self.last_request_time: + min_interval = timedelta(seconds=1.0 / self.requests_per_second) + time_since_last = datetime.now() - self.last_request_time + if time_since_last < min_interval: + await asyncio.sleep((min_interval - time_since_last).total_seconds()) + self.last_request_time = datetime.now() + + def _sync_wait_for_rate_limit(self): + """Synchronous version of rate limit wait.""" + if self.requests_per_second and self.last_request_time: + min_interval = timedelta(seconds=1.0 / self.requests_per_second) + time_since_last = datetime.now() - self.last_request_time + if time_since_last < min_interval: + time.sleep((min_interval - time_since_last).total_seconds()) + self.last_request_time = datetime.now() + + async def _safe_process_url(self, url: str) -> bool: + """Perform safety checks before processing a URL.""" + if self.verify_ssl and not self._verify_ssl_cert(url): + raise ValueError(f"SSL certificate verification failed for {url}") + await self._wait_for_rate_limit() + return True + + def _safe_process_url_sync(self, url: str) -> bool: + """Synchronous version of safety checks.""" + if self.verify_ssl and not self._verify_ssl_cert(url): + raise ValueError(f"SSL certificate verification failed for {url}") + self._sync_wait_for_rate_limit() + return True + + class SafeWebBaseLoader(WebBaseLoader): """WebBaseLoader with enhanced error handling for URLs.""" @@ -143,20 +467,12 @@ class SafeWebBaseLoader(WebBaseLoader): text = soup.get_text(**self.bs_get_text_kwargs) # Build metadata - metadata = {"source": path} - if title := soup.find("title"): - metadata["title"] = title.get_text() - if description := soup.find("meta", attrs={"name": "description"}): - metadata["description"] = description.get( - "content", "No description found."
- ) - if html := soup.find("html"): - metadata["language"] = html.get("lang", "No language found.") + metadata = extract_metadata(soup, path) yield Document(page_content=text, metadata=metadata) except Exception as e: # Log the error and continue with the next URL - log.error(f"Error loading {path}: {e}") + log.exception("Error loading %s", path) async def alazy_load(self) -> AsyncIterator[Document]: """Async lazy load text from the url(s) in web_path.""" @@ -179,6 +495,12 @@ class SafeWebBaseLoader(WebBaseLoader): return [document async for document in self.alazy_load()] +RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader) +RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader +RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader +RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader + + def get_web_loader( urls: Union[str, Sequence[str]], verify_ssl: bool = True, @@ -188,10 +510,29 @@ # Check if the URLs are valid safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls) - return SafeWebBaseLoader( - web_path=safe_urls, - verify_ssl=verify_ssl, - requests_per_second=requests_per_second, - continue_on_failure=True, - trust_env=trust_env, + web_loader_args = { + "web_paths": safe_urls, + "verify_ssl": verify_ssl, + "requests_per_second": requests_per_second, + "continue_on_failure": True, + "trust_env": trust_env, + } + + # Only the Playwright loader accepts a ws endpoint; passing it to the other + # loaders would raise a TypeError on an unexpected keyword argument + if RAG_WEB_LOADER_ENGINE.value == "playwright" and PLAYWRIGHT_WS_URI.value: + web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value + + if RAG_WEB_LOADER_ENGINE.value == "firecrawl": + web_loader_args["api_key"] = FIRECRAWL_API_KEY.value + web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value + + # Create the appropriate WebLoader based on the configuration + WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value] + web_loader = WebLoaderClass(**web_loader_args) + + log.debug( + "Using RAG_WEB_LOADER_ENGINE %s for %s URLs", + web_loader.__class__.__name__, + len(safe_urls), ) + + return web_loader diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index e2d05ba90..a970366d1 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -37,6 +37,7 @@ from open_webui.config import ( from open_webui.constants import ERROR_MESSAGES from open_webui.env import ( + AIOHTTP_CLIENT_TIMEOUT, ENV, SRC_LOG_LEVELS, DEVICE_TYPE, @@ -266,7 +267,10 @@ async def speech(request: Request, user=Depends(get_verified_user)): try: # print(payload) - async with aiohttp.ClientSession() as session: + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + async with aiohttp.ClientSession( + timeout=timeout, trust_env=True + ) as session: async with session.post( url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", json=payload, @@ -323,7 +327,10 @@ async def speech(request: Request, user=Depends(get_verified_user)): ) try: - async with aiohttp.ClientSession() as session: + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + async with aiohttp.ClientSession( + timeout=timeout, trust_env=True + ) as session: async with session.post( f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}", json={ @@ -380,7 +387,10 @@ async def speech(request: Request, user=Depends(get_verified_user)): data = f""" {payload["input"]} """ - async with aiohttp.ClientSession() as session: + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + async with aiohttp.ClientSession( + timeout=timeout, trust_env=True + ) as session: async with session.post(
f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1", headers={ diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index a3f2e8b32..3fa2ffe2e 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -546,7 +546,8 @@ async def signout(request: Request, response: Response): if logout_url: response.delete_cookie("oauth_id_token") return RedirectResponse( - url=f"{logout_url}?id_token_hint={oauth_id_token}" + headers=response.headers, + url=f"{logout_url}?id_token_hint={oauth_id_token}", ) else: raise HTTPException( diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index d460ae670..388c44f9c 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -75,6 +75,7 @@ class CodeInterpreterConfigForm(BaseModel): CODE_EXECUTION_JUPYTER_AUTH: Optional[str] CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str] CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str] + CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int] ENABLE_CODE_INTERPRETER: bool CODE_INTERPRETER_ENGINE: str CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str] @@ -82,6 +83,7 @@ class CodeInterpreterConfigForm(BaseModel): CODE_INTERPRETER_JUPYTER_AUTH: Optional[str] CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str] CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str] + CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int] @router.get("/code_execution", response_model=CodeInterpreterConfigForm) @@ -92,6 +94,7 @@ async def get_code_execution_config(request: Request, user=Depends(get_admin_use "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH, "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN, "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD, + "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT, "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER, "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE, "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE, @@ -99,6 +102,7 @@ async def get_code_execution_config(request: Request, user=Depends(get_admin_use "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, + "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, } @@ -120,6 +124,9 @@ async def set_code_execution_config( request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = ( form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD ) + request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = ( + form_data.CODE_EXECUTION_JUPYTER_TIMEOUT + ) request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE @@ -141,6 +148,9 @@ async def set_code_execution_config( request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = ( form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD ) + request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = ( + form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT + ) return { "CODE_EXECUTION_ENGINE": 
request.app.state.config.CODE_EXECUTION_ENGINE, @@ -148,6 +158,7 @@ async def set_code_execution_config( "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH, "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN, "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD, + "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT, "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER, "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE, "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE, @@ -155,6 +166,7 @@ async def set_code_execution_config( "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, + "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, } diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 051321257..504baa60d 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -225,17 +225,24 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): filename = file.meta.get("name", file.filename) encoded_filename = quote(filename) # RFC5987 encoding + content_type = file.meta.get("content_type") + filename = file.meta.get("name", file.filename) + encoded_filename = quote(filename) headers = {} - if file.meta.get("content_type") not in [ - "application/pdf", - "text/plain", - ]: - headers = { - **headers, - "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}", - } - return FileResponse(file_path, headers=headers) + if content_type == "application/pdf" or filename.lower().endswith( + ".pdf" + ): + headers["Content-Disposition"] = ( + f"inline; filename*=UTF-8''{encoded_filename}" + ) + content_type = "application/pdf" + elif content_type != "text/plain": + headers["Content-Disposition"] = ( + f"attachment; filename*=UTF-8''{encoded_filename}" + ) + + return FileResponse(file_path, headers=headers, media_type=content_type) else: raise HTTPException( diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 4046773de..3288ec6d8 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -55,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)): "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, + "gemini": { + "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, + "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, + }, } @@ -78,6 +82,11 @@ class ComfyUIConfigForm(BaseModel): COMFYUI_WORKFLOW_NODES: list[dict] +class GeminiConfigForm(BaseModel): + GEMINI_API_BASE_URL: str + GEMINI_API_KEY: str + + class ConfigForm(BaseModel): enabled: bool engine: str @@ -85,6 +94,7 @@ class ConfigForm(BaseModel): openai: OpenAIConfigForm automatic1111: Automatic1111ConfigForm comfyui: ComfyUIConfigForm + gemini: GeminiConfigForm @router.post("/config/update") @@ -103,6 +113,11 @@ async def update_config( ) request.app.state.config.IMAGES_OPENAI_API_KEY 
= form_data.openai.OPENAI_API_KEY + request.app.state.config.IMAGES_GEMINI_API_BASE_URL = ( + form_data.gemini.GEMINI_API_BASE_URL + ) + request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY + request.app.state.config.AUTOMATIC1111_BASE_URL = ( form_data.automatic1111.AUTOMATIC1111_BASE_URL ) @@ -155,6 +170,10 @@ async def update_config( "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, + "gemini": { + "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, + "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, + }, } @@ -224,6 +243,12 @@ def get_image_model(request): if request.app.state.config.IMAGE_GENERATION_MODEL else "dall-e-2" ) + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": + return ( + request.app.state.config.IMAGE_GENERATION_MODEL + if request.app.state.config.IMAGE_GENERATION_MODEL + else "imagen-3.0-generate-002" + ) elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": return ( request.app.state.config.IMAGE_GENERATION_MODEL @@ -299,6 +324,10 @@ def get_models(request: Request, user=Depends(get_verified_user)): {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": + return [ + {"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"}, + ] elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": # TODO - get models from comfyui headers = { @@ -483,6 +512,41 @@ async def image_generations( images.append({"url": url}) return images + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": + headers = {} + headers["Content-Type"] = "application/json" + headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY + + model = get_image_model(request) + data = { + "instances": {"prompt": form_data.prompt}, + "parameters": { + "sampleCount": form_data.n, + "outputOptions": {"mimeType": "image/png"}, + }, + } + + # Use asyncio.to_thread for the requests.post call + r = await asyncio.to_thread( + requests.post, + url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict", + json=data, + headers=headers, + ) + + r.raise_for_status() + res = r.json() + + images = [] + for image in res["predictions"]: + image_data, content_type = load_b64_image_data( + image["bytesBase64Encoded"] + ) + url = upload_image(request, data, image_data, content_type, user) + images.append({"url": url}) + + return images + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": data = { "prompt": form_data.prompt, diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 4fca10e1f..732dd36f9 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -26,7 +26,7 @@ from fastapi import ( ) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, validator from starlette.background import BackgroundTask @@ -936,10 +936,23 @@ async def generate_completion( class ChatMessage(BaseModel): role: str - content: str + content: Optional[str] = None tool_calls: Optional[list[dict]] = None images: Optional[list[str]] = None + @validator("content", pre=True) + @classmethod + def check_at_least_one_field(cls, field_value, values, **kwargs): + # Raise an error if
both 'content' and 'tool_calls' are None + if field_value is None and ( + "tool_calls" not in values or values["tool_calls"] is None + ): + raise ValueError( + "At least one of 'content' or 'tool_calls' must be provided" + ) + + return field_value + class GenerateChatCompletionForm(BaseModel): model: str diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 71eec6a68..e69d2ce96 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -351,6 +351,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): return { "status": True, "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, + "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, "enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, "content_extraction": { "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, @@ -463,6 +464,7 @@ class WebConfig(BaseModel): class ConfigUpdateForm(BaseModel): + RAG_FULL_CONTEXT: Optional[bool] = None pdf_extract_images: Optional[bool] = None enable_google_drive_integration: Optional[bool] = None file: Optional[FileConfig] = None @@ -482,6 +484,12 @@ async def update_rag_config( else request.app.state.config.PDF_EXTRACT_IMAGES ) + request.app.state.config.RAG_FULL_CONTEXT = ( + form_data.RAG_FULL_CONTEXT + if form_data.RAG_FULL_CONTEXT is not None + else request.app.state.config.RAG_FULL_CONTEXT + ) + request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ( form_data.enable_google_drive_integration if form_data.enable_google_drive_integration is not None @@ -588,6 +596,7 @@ async def update_rag_config( return { "status": True, "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, + "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, "file": { "max_size": request.app.state.config.FILE_MAX_SIZE, "max_count": request.app.state.config.FILE_MAX_COUNT, diff --git a/backend/open_webui/routers/utils.py b/backend/open_webui/routers/utils.py index 61863bda5..fb1dc8272 100644 --- a/backend/open_webui/routers/utils.py +++ b/backend/open_webui/routers/utils.py @@ -56,6 +56,7 @@ async def execute_code( if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password" else None ), + request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT, ) return output diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py index b03cf0a7e..160a45153 100644 --- a/backend/open_webui/storage/provider.py +++ b/backend/open_webui/storage/provider.py @@ -15,12 +15,18 @@ from open_webui.config import ( S3_SECRET_ACCESS_KEY, GCS_BUCKET_NAME, GOOGLE_APPLICATION_CREDENTIALS_JSON, + AZURE_STORAGE_ENDPOINT, + AZURE_STORAGE_CONTAINER_NAME, + AZURE_STORAGE_KEY, STORAGE_PROVIDER, UPLOAD_DIR, ) from google.cloud import storage from google.cloud.exceptions import GoogleCloudError, NotFound from open_webui.constants import ERROR_MESSAGES +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient +from azure.core.exceptions import ResourceNotFoundError class StorageProvider(ABC): @@ -221,6 +227,74 @@ class GCSStorageProvider(StorageProvider): LocalStorageProvider.delete_all_files() +class AzureStorageProvider(StorageProvider): + def __init__(self): + self.endpoint = AZURE_STORAGE_ENDPOINT + self.container_name = AZURE_STORAGE_CONTAINER_NAME + storage_key = AZURE_STORAGE_KEY + + if storage_key: + # Configure using the Azure Storage Account Endpoint and Key + 
self.blob_service_client = BlobServiceClient( + account_url=self.endpoint, credential=storage_key + ) + else: + # Configure using the Azure Storage Account Endpoint and DefaultAzureCredential + # If the key is not configured, then the DefaultAzureCredential will be used to support Managed Identity authentication + self.blob_service_client = BlobServiceClient( + account_url=self.endpoint, credential=DefaultAzureCredential() + ) + self.container_client = self.blob_service_client.get_container_client( + self.container_name + ) + + def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]: + """Handles uploading of the file to Azure Blob Storage.""" + contents, file_path = LocalStorageProvider.upload_file(file, filename) + try: + blob_client = self.container_client.get_blob_client(filename) + blob_client.upload_blob(contents, overwrite=True) + return contents, f"{self.endpoint}/{self.container_name}/{filename}" + except Exception as e: + raise RuntimeError(f"Error uploading file to Azure Blob Storage: {e}") + + def get_file(self, file_path: str) -> str: + """Handles downloading of the file from Azure Blob Storage.""" + try: + filename = file_path.split("/")[-1] + local_file_path = f"{UPLOAD_DIR}/{filename}" + blob_client = self.container_client.get_blob_client(filename) + with open(local_file_path, "wb") as download_file: + download_file.write(blob_client.download_blob().readall()) + return local_file_path + except ResourceNotFoundError as e: + raise RuntimeError(f"Error downloading file from Azure Blob Storage: {e}") + + def delete_file(self, file_path: str) -> None: + """Handles deletion of the file from Azure Blob Storage.""" + try: + filename = file_path.split("/")[-1] + blob_client = self.container_client.get_blob_client(filename) + blob_client.delete_blob() + except ResourceNotFoundError as e: + raise RuntimeError(f"Error deleting file from Azure Blob Storage: {e}") + + # Always delete from local storage + LocalStorageProvider.delete_file(file_path) + + def delete_all_files(self) -> None: + """Handles deletion of all files from Azure Blob Storage.""" + try: + blobs = self.container_client.list_blobs() + for blob in blobs: + self.container_client.delete_blob(blob.name) + except Exception as e: + raise RuntimeError(f"Error deleting all files from Azure Blob Storage: {e}") + + # Always delete from local storage + LocalStorageProvider.delete_all_files() + + def get_storage_provider(storage_provider: str): if storage_provider == "local": Storage = LocalStorageProvider() @@ -228,6 +302,8 @@ def get_storage_provider(storage_provider: str): Storage = S3StorageProvider() elif storage_provider == "gcs": Storage = GCSStorageProvider() + elif storage_provider == "azure": + Storage = AzureStorageProvider() else: raise RuntimeError(f"Unsupported storage provider: {storage_provider}") return Storage diff --git a/backend/open_webui/test/apps/webui/storage/test_provider.py b/backend/open_webui/test/apps/webui/storage/test_provider.py index 863106e75..a5ef13504 100644 --- a/backend/open_webui/test/apps/webui/storage/test_provider.py +++ b/backend/open_webui/test/apps/webui/storage/test_provider.py @@ -7,6 +7,8 @@ from moto import mock_aws from open_webui.storage import provider from gcp_storage_emulator.server import create_server from google.cloud import storage +from azure.storage.blob import BlobServiceClient, ContainerClient, BlobClient +from unittest.mock import MagicMock def mock_upload_dir(monkeypatch, tmp_path): @@ -22,6 +24,7 @@ def test_imports(): provider.LocalStorageProvider 
provider.S3StorageProvider provider.GCSStorageProvider + provider.AzureStorageProvider provider.Storage @@ -32,6 +35,8 @@ def test_get_storage_provider(): assert isinstance(Storage, provider.S3StorageProvider) Storage = provider.get_storage_provider("gcs") assert isinstance(Storage, provider.GCSStorageProvider) + Storage = provider.get_storage_provider("azure") + assert isinstance(Storage, provider.AzureStorageProvider) with pytest.raises(RuntimeError): provider.get_storage_provider("invalid") @@ -48,6 +53,7 @@ def test_class_instantiation(): provider.LocalStorageProvider() provider.S3StorageProvider() provider.GCSStorageProvider() + provider.AzureStorageProvider() class TestLocalStorageProvider: @@ -272,3 +278,147 @@ class TestGCSStorageProvider: assert not (upload_dir / self.filename_extra).exists() assert self.Storage.bucket.get_blob(self.filename) == None assert self.Storage.bucket.get_blob(self.filename_extra) == None + + +class TestAzureStorageProvider: + @pytest.fixture(autouse=True) + def setup_storage(self, monkeypatch): + # Create mock Blob Service Client and related clients + mock_blob_service_client = MagicMock() + mock_container_client = MagicMock() + mock_blob_client = MagicMock() + + # Set up return values for the mock + mock_blob_service_client.get_container_client.return_value = ( + mock_container_client + ) + mock_container_client.get_blob_client.return_value = mock_blob_client + + # Monkeypatch the BlobServiceClient name the provider module actually uses + monkeypatch.setattr( + provider, + "BlobServiceClient", + lambda *args, **kwargs: mock_blob_service_client, + ) + + self.Storage = provider.AzureStorageProvider() + self.Storage.endpoint = "https://myaccount.blob.core.windows.net" + self.Storage.container_name = "my-container" + self.file_content = b"test content" + self.filename = "test.txt" + self.filename_extra = "test_extra.txt" + self.file_bytesio_empty = io.BytesIO() + + # Apply mocks to the Storage instance + self.Storage.blob_service_client = mock_blob_service_client + self.Storage.container_client = mock_container_client + + def test_upload_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + + # Simulate an error when container does not exist + self.Storage.container_client.get_blob_client.side_effect = Exception( + "Container does not exist" + ) + with pytest.raises(Exception): + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + + # Reset the side effect + self.Storage.container_client.get_blob_client.side_effect = None + contents, azure_file_path = self.Storage.upload_file( + io.BytesIO(self.file_content), self.filename + ) + + # Assertions + self.Storage.container_client.get_blob_client.assert_called_with(self.filename) + self.Storage.container_client.get_blob_client().upload_blob.assert_called_once_with( + self.file_content, overwrite=True + ) + assert contents == self.file_content + assert ( + azure_file_path + == f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}" + ) + assert (upload_dir / self.filename).exists() + assert (upload_dir / self.filename).read_bytes() == self.file_content + + with pytest.raises(ValueError):
self.Storage.upload_file(self.file_bytesio_empty, self.filename) + + def test_get_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + + # Mock upload behavior + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + # Mock blob download behavior + self.Storage.container_client.get_blob_client().download_blob().readall.return_value = ( self.file_content ) + + file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}" + file_path = self.Storage.get_file(file_url) + + assert file_path == str(upload_dir / self.filename) + assert (upload_dir / self.filename).exists() + assert (upload_dir / self.filename).read_bytes() == self.file_content + + def test_delete_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + + # Mock file upload + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + # Mock deletion + self.Storage.container_client.get_blob_client().delete_blob.return_value = None + + file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}" + self.Storage.delete_file(file_url) + + self.Storage.container_client.get_blob_client().delete_blob.assert_called_once() + assert not (upload_dir / self.filename).exists() + + def test_delete_all_files(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + + # Mock file uploads + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra) + + # Mock listing and deletion behavior; the provider reads blob.name, so the + # mocked blobs need a name attribute rather than being plain dicts + mock_blob = MagicMock() + mock_blob.name = self.filename + mock_blob_extra = MagicMock() + mock_blob_extra.name = self.filename_extra + self.Storage.container_client.list_blobs.return_value = [ + mock_blob, + mock_blob_extra, + ] + self.Storage.container_client.delete_blob.return_value = None + + self.Storage.delete_all_files() + + self.Storage.container_client.list_blobs.assert_called_once() + self.Storage.container_client.delete_blob.assert_any_call(self.filename) + self.Storage.container_client.delete_blob.assert_any_call(self.filename_extra) + assert not (upload_dir / self.filename).exists() + assert not (upload_dir / self.filename_extra).exists() + + def test_get_file_not_found(self, monkeypatch): + file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}" + # Mock behavior to raise an error for missing blobs + self.Storage.container_client.get_blob_client().download_blob.side_effect = ( Exception("Blob not found") ) + with pytest.raises(Exception, match="Blob not found"): + self.Storage.get_file(file_url) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 93edc8f72..7ec764fc0 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -321,89 +321,94 @@ async def chat_web_search_handler( ) return form_data - searchQuery = queries[0] + all_results = [] - await event_emitter( - { - "type": "status", - "data": { - "action": "web_search", - "description": 'Searching "{{searchQuery}}"', - "query": searchQuery, - "done": False, - }, - } - ) - - try: - - results = await process_web_search( - request, - SearchForm( - **{ + for searchQuery in queries: + await event_emitter( + { + "type": "status", + "data": { + "action": "web_search", + "description": 'Searching "{{searchQuery}}"', "query": searchQuery, - } - ), - user, + "done": False, + }, + } ) - if results: -
await event_emitter( - { - "type": "status", - "data": { - "action": "web_search", - "description": "Searched {{count}} sites", + try: + results = await process_web_search( + request, + SearchForm( + **{ "query": searchQuery, - "urls": results["filenames"], - "done": True, - }, - } + } + ), + user=user, ) - files = form_data.get("files", []) + if results: + all_results.append(results) + files = form_data.get("files", []) - if request.app.state.config.RAG_WEB_SEARCH_FULL_CONTEXT: - files.append( - { - "docs": results.get("docs", []), - "name": searchQuery, - "type": "web_search_docs", - "urls": results["filenames"], - } - ) - else: - files.append( - { - "collection_name": results["collection_name"], - "name": searchQuery, - "type": "web_search_results", - "urls": results["filenames"], - } - ) - form_data["files"] = files - else: + if request.app.state.config.RAG_WEB_SEARCH_FULL_CONTEXT: + files.append( + { + "docs": results.get("docs", []), + "name": searchQuery, + "type": "web_search_docs", + "urls": results["filenames"], + } + ) + else: + files.append( + { + "collection_name": results["collection_name"], + "name": searchQuery, + "type": "web_search_results", + "urls": results["filenames"], + } + ) + form_data["files"] = files + except Exception as e: + log.exception(e) await event_emitter( { "type": "status", "data": { "action": "web_search", - "description": "No search results found", + "description": 'Error searching "{{searchQuery}}"', "query": searchQuery, "done": True, "error": True, }, } ) - except Exception as e: - log.exception(e) + + if all_results: + urls = [] + for results in all_results: + if "filenames" in results: + urls.extend(results["filenames"]) + await event_emitter( { "type": "status", "data": { "action": "web_search", - "description": 'Error searching "{{searchQuery}}"', - "query": searchQuery, + "description": "Searched {{count}} sites", + "urls": urls, + "done": True, + }, + } + ) + else: + await event_emitter( + { + "type": "status", + "data": { + "action": "web_search", + "description": "No search results found", "done": True, "error": True, }, @@ -560,9 +565,9 @@ async def chat_completion_files_handler( reranking_function=request.app.state.rf, r=request.app.state.config.RELEVANCE_THRESHOLD, hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, + full_context=request.app.state.config.RAG_FULL_CONTEXT, ), ) - except Exception as e: log.exception(e) @@ -1359,7 +1364,15 @@ async def process_chat_response( tool_calls = [] - last_assistant_message = get_last_assistant_message(form_data["messages"]) + last_assistant_message = None + try: + if form_data["messages"][-1]["role"] == "assistant": + last_assistant_message = get_last_assistant_message( + form_data["messages"] + ) + except Exception as e: + pass + content = ( message.get("content", "") if message @@ -1748,6 +1761,7 @@ async def process_chat_response( == "password" else None ), + request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, ) else: output = { diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index a635853d6..13835e784 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -140,7 +140,14 @@ class OAuthManager: log.debug("Running OAUTH Group management") oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM - user_oauth_groups: list[str] = user_data.get(oauth_claim, list()) + # Nested claim search for groups claim + if oauth_claim: + claim_data = user_data + nested_claims = oauth_claim.split(".") + for nested_claim in 
nested_claims: + claim_data = claim_data.get(nested_claim, {}) + user_oauth_groups = claim_data if isinstance(claim_data, list) else [] + else: + user_oauth_groups = [] + user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id) all_available_groups: list[GroupModel] = Groups.get_groups() @@ -239,11 +246,46 @@ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) provider_sub = f"{provider}@{sub}" email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM - email = user_data.get(email_claim, "").lower() + email = user_data.get(email_claim, "") # We currently mandate that email addresses are provided if not email: - log.warning(f"OAuth callback failed, email is missing: {user_data}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + # If the provider is GitHub, and a public email is not provided, we can use the access token to fetch the user's email + if provider == "github": + try: + access_token = token.get("access_token") + headers = {"Authorization": f"Bearer {access_token}"} + async with aiohttp.ClientSession() as session: + async with session.get( + "https://api.github.com/user/emails", headers=headers ) as resp: + if resp.ok: + emails = await resp.json() + # use the primary email as the user's email + primary_email = next( + (e["email"] for e in emails if e.get("primary")), + None, + ) + if primary_email: + email = primary_email + else: + log.warning( + "No primary email found in GitHub response" + ) + raise HTTPException( + 400, detail=ERROR_MESSAGES.INVALID_CRED + ) + else: + log.warning("Failed to fetch GitHub email") + raise HTTPException( + 400, detail=ERROR_MESSAGES.INVALID_CRED + ) + except Exception as e: + log.warning(f"Error fetching GitHub email: {e}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + else: + log.warning(f"OAuth callback failed, email is missing: {user_data}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + email = email.lower() if ( "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS @@ -285,9 +327,7 @@ # If the user does not exist, check if signups are enabled if auth_manager_config.ENABLE_OAUTH_SIGNUP: # Check if an existing user with the same email already exists - existing_user = Users.get_user_by_email( - user_data.get("email", "").lower() - ) + existing_user = Users.get_user_by_email(email) if existing_user: raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index 5eb040434..51e8d50cc 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -4,6 +4,7 @@ from open_webui.utils.misc import ( ) from typing import Callable, Optional +import json # inplace function: form_data is modified @@ -66,38 +67,49 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: - opts = [ - "temperature", - "top_p", - "seed", - "mirostat", - "mirostat_eta", - "mirostat_tau", - "num_ctx", - "num_batch", - "num_keep", - "repeat_last_n", - "tfs_z", - "top_k", - "min_p", - "use_mmap", - "use_mlock", - "num_thread", - "num_gpu", - ] - mappings = {i: lambda x: x for i in opts} - form_data = apply_model_params_to_body(params, form_data, mappings) - + # Convert OpenAI parameter names to Ollama parameter names if needed. 
name_differences = { "max_tokens": "num_predict", - "frequency_penalty": "repeat_penalty", } for key, value in name_differences.items(): if (param := params.get(key, None)) is not None: - form_data[value] = param + # Copy the parameter under its new name, then delete the original so Ollama does not warn about an invalid option + params[value] = param + del params[key] - return form_data + # See https://github.com/ollama/ollama/blob/main/docs/api.md#request-8 + mappings = { + "temperature": float, + "top_p": float, + "seed": lambda x: x, + "mirostat": int, + "mirostat_eta": float, + "mirostat_tau": float, + "num_ctx": int, + "num_batch": int, + "num_keep": int, + "num_predict": int, + "repeat_last_n": int, + "top_k": int, + "min_p": float, + "typical_p": float, + "repeat_penalty": float, + "presence_penalty": float, + "frequency_penalty": float, + "penalize_newline": bool, + "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], + "numa": bool, + "num_gpu": int, + "main_gpu": int, + "low_vram": bool, + "vocab_only": bool, + "use_mmap": bool, + "use_mlock": bool, + "num_thread": int, + } + + return apply_model_params_to_body(params, form_data, mappings) def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]: @@ -108,11 +120,38 @@ new_message = {"role": message["role"]} content = message.get("content", []) + tool_calls = message.get("tool_calls", None) + tool_call_id = message.get("tool_call_id", None) # Check if the content is a string (just a simple message) if isinstance(content, str): # If the content is a string, it's pure text new_message["content"] = content + + # If the message is a tool response, carry the tool call id over + if tool_call_id: + new_message["tool_call_id"] = tool_call_id + + elif tool_calls: + # If tool calls are present, add them to the message + ollama_tool_calls = [] + for tool_call in tool_calls: + ollama_tool_call = { + "index": tool_call.get("index", 0), + "id": tool_call.get("id", None), + "function": { + "name": tool_call.get("function", {}).get("name", ""), + "arguments": json.loads( + tool_call.get("function", {}).get("arguments", "{}") + ), + }, + } + ollama_tool_calls.append(ollama_tool_call) + new_message["tool_calls"] = ollama_tool_calls + + # Set the content to an empty string (Ollama expects an empty string when tool calls are present) + new_message["content"] = "" + else: # Otherwise, assume the content is a list of dicts, e.g., text followed by an image URL content_text = "" @@ -173,34 +212,23 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict: ollama_payload["format"] = openai_payload["format"] # If there are advanced parameters in the payload, format them in Ollama's options field - ollama_options = {} - if openai_payload.get("options"): ollama_payload["options"] = openai_payload["options"] ollama_options = openai_payload["options"] - # Handle parameters which map directly - for param in ["temperature", "top_p", "seed"]: - if param in openai_payload: - ollama_options[param] = openai_payload[param] + # Remap OpenAI's `max_tokens` to Ollama's `num_predict` + if "max_tokens" in ollama_options: + ollama_options["num_predict"] = ollama_options["max_tokens"] + del ollama_options[ + "max_tokens" + ] # To prevent Ollama warning of invalid option provided - # Mapping OpenAI's `max_tokens` -> Ollama's `num_predict` - if "max_completion_tokens" in openai_payload: - ollama_options["num_predict"] = openai_payload["max_completion_tokens"] - elif 
"max_tokens" in openai_payload: - ollama_options["num_predict"] = openai_payload["max_tokens"] - - # Handle frequency / presence_penalty, which needs renaming and checking - if "frequency_penalty" in openai_payload: - ollama_options["repeat_penalty"] = openai_payload["frequency_penalty"] - - if "presence_penalty" in openai_payload and "penalty" not in ollama_options: - # We are assuming presence penalty uses a similar concept in Ollama, which needs custom handling if exists. - ollama_options["new_topic_penalty"] = openai_payload["presence_penalty"] - - # Add options to payload if any have been set - if ollama_options: - ollama_payload["options"] = ollama_options + # Ollama lacks a "system" prompt option. It has to be provided as a direct parameter, so we copy it down. + if "system" in ollama_options: + ollama_payload["system"] = ollama_options["system"] + del ollama_options[ + "system" + ] # To prevent Ollama warning of invalid option provided if "metadata" in openai_payload: ollama_payload["metadata"] = openai_payload["metadata"] diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py index f9979b4a2..bc47e1e13 100644 --- a/backend/open_webui/utils/response.py +++ b/backend/open_webui/utils/response.py @@ -24,17 +24,8 @@ def convert_ollama_tool_call_to_openai(tool_calls: dict) -> dict: return openai_tool_calls -def convert_response_ollama_to_openai(ollama_response: dict) -> dict: - model = ollama_response.get("model", "ollama") - message_content = ollama_response.get("message", {}).get("content", "") - tool_calls = ollama_response.get("message", {}).get("tool_calls", None) - openai_tool_calls = None - - if tool_calls: - openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls) - - data = ollama_response - usage = { +def convert_ollama_usage_to_openai(data: dict) -> dict: + return { "response_token/s": ( round( ( @@ -66,14 +57,42 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict: "total_duration": data.get("total_duration", 0), "load_duration": data.get("load_duration", 0), "prompt_eval_count": data.get("prompt_eval_count", 0), + "prompt_tokens": int( + data.get("prompt_eval_count", 0) + ), # This is the OpenAI compatible key "prompt_eval_duration": data.get("prompt_eval_duration", 0), "eval_count": data.get("eval_count", 0), + "completion_tokens": int( + data.get("eval_count", 0) + ), # This is the OpenAI compatible key "eval_duration": data.get("eval_duration", 0), "approximate_total": (lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s")( (data.get("total_duration", 0) or 0) // 1_000_000_000 ), + "total_tokens": int( # This is the OpenAI compatible key + data.get("prompt_eval_count", 0) + data.get("eval_count", 0) + ), + "completion_tokens_details": { # This is the OpenAI compatible key + "reasoning_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, } + +def convert_response_ollama_to_openai(ollama_response: dict) -> dict: + model = ollama_response.get("model", "ollama") + message_content = ollama_response.get("message", {}).get("content", "") + tool_calls = ollama_response.get("message", {}).get("tool_calls", None) + openai_tool_calls = None + + if tool_calls: + openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls) + + data = ollama_response + + usage = convert_ollama_usage_to_openai(data) + response = openai_chat_completion_message_template( model, message_content, openai_tool_calls, usage ) @@ -96,45 +115,7 @@ async def 
convert_streaming_response_ollama_to_openai(ollama_streaming_response) usage = None if done: - usage = { - "response_token/s": ( - round( - ( - ( - data.get("eval_count", 0) - / ((data.get("eval_duration", 0) / 10_000_000)) - ) - * 100 - ), - 2, - ) - if data.get("eval_duration", 0) > 0 - else "N/A" - ), - "prompt_token/s": ( - round( - ( - ( - data.get("prompt_eval_count", 0) - / ((data.get("prompt_eval_duration", 0) / 10_000_000)) - ) - * 100 - ), - 2, - ) - if data.get("prompt_eval_duration", 0) > 0 - else "N/A" - ), - "total_duration": data.get("total_duration", 0), - "load_duration": data.get("load_duration", 0), - "prompt_eval_count": data.get("prompt_eval_count", 0), - "prompt_eval_duration": data.get("prompt_eval_duration", 0), - "eval_count": data.get("eval_count", 0), - "eval_duration": data.get("eval_duration", 0), - "approximate_total": ( - lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s" - )((data.get("total_duration", 0) or 0) // 1_000_000_000), - } + usage = convert_ollama_usage_to_openai(data) data = openai_chat_chunk_message_template( model, message_content if not done else None, openai_tool_calls, usage diff --git a/backend/requirements.txt b/backend/requirements.txt index 9b859b84a..e4c594e58 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,6 +1,6 @@ fastapi==0.115.7 uvicorn[standard]==0.30.6 -pydantic==2.9.2 +pydantic==2.10.6 python-multipart==0.0.18 python-socketio==5.11.3 @@ -45,7 +45,7 @@ chromadb==0.6.2 pymilvus==2.5.0 qdrant-client~=1.12.0 opensearch-py==2.8.0 - +playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml transformers sentence-transformers==3.3.1 @@ -59,7 +59,7 @@ fpdf2==2.8.2 pymdown-extensions==10.14.2 docx2txt==0.8 python-pptx==1.0.0 -unstructured==0.16.11 +unstructured==0.16.17 nltk==3.9.1 Markdown==3.7 pypandoc==1.13 @@ -103,5 +103,12 @@ pytest-docker~=3.1.1 googleapis-common-protos==1.63.2 google-cloud-storage==2.19.0 +azure-identity==1.20.0 +azure-storage-blob==12.24.1 + + ## LDAP ldap3==2.9.1 + +## Firecrawl +firecrawl-py==1.12.0 diff --git a/backend/start.sh b/backend/start.sh index a945acb62..671c22ff7 100755 --- a/backend/start.sh +++ b/backend/start.sh @@ -3,6 +3,17 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) cd "$SCRIPT_DIR" || exit +# Add conditional Playwright browser installation +if [[ "${RAG_WEB_LOADER_ENGINE,,}" == "playwright" ]]; then + if [[ -z "${PLAYWRIGHT_WS_URI}" ]]; then + echo "Installing Playwright browsers..." + playwright install chromium + playwright install-deps chromium + fi + + python -c "import nltk; nltk.download('punkt_tab')" +fi + KEY_FILE=.webui_secret_key PORT="${PORT:-8080}" diff --git a/backend/start_windows.bat b/backend/start_windows.bat index 3e8c6b97c..7049cd1b3 100644 --- a/backend/start_windows.bat +++ b/backend/start_windows.bat @@ -6,6 +6,17 @@ SETLOCAL ENABLEDELAYEDEXPANSION SET "SCRIPT_DIR=%~dp0" cd /d "%SCRIPT_DIR%" || exit /b +:: Add conditional Playwright browser installation +IF /I "%RAG_WEB_LOADER_ENGINE%" == "playwright" ( + IF "%PLAYWRIGHT_WS_URI%" == "" ( + echo Installing Playwright browsers... 
+ playwright install chromium + playwright install-deps chromium + ) + + python -c "import nltk; nltk.download('punkt_tab')" +) + SET "KEY_FILE=.webui_secret_key" IF "%PORT%"=="" SET PORT=8080 IF "%HOST%"=="" SET HOST=0.0.0.0 diff --git a/docker-compose.playwright.yaml b/docker-compose.playwright.yaml new file mode 100644 index 000000000..fe570bed0 --- /dev/null +++ b/docker-compose.playwright.yaml @@ -0,0 +1,10 @@ +services: + playwright: + image: mcr.microsoft.com/playwright:v1.49.1-noble # Version must match requirements.txt + container_name: playwright + command: npx -y playwright@1.49.1 run-server --port 3000 --host 0.0.0.0 + + open-webui: + environment: + - 'RAG_WEB_LOADER_ENGINE=playwright' + - 'PLAYWRIGHT_WS_URI=ws://playwright:3000' \ No newline at end of file diff --git a/package-lock.json b/package-lock.json index 3d443d49f..bde18359d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.5.14", + "version": "0.5.15", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.5.14", + "version": "0.5.15", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", @@ -63,6 +63,7 @@ "svelte-sonner": "^0.3.19", "tippy.js": "^6.3.7", "turndown": "^7.2.0", + "undici": "^7.3.0", "uuid": "^9.0.1", "vite-plugin-static-copy": "^2.2.0" }, @@ -11528,6 +11529,15 @@ "node": "*" } }, + "node_modules/undici": { + "version": "7.3.0", + "resolved": "https://registry.npmjs.org/undici/-/undici-7.3.0.tgz", + "integrity": "sha512-Qy96NND4Dou5jKoSJ2gm8ax8AJM/Ey9o9mz7KN1bb9GP+G0l20Zw8afxTnY2f4b7hmhn/z8aC2kfArVQlAhFBw==", + "license": "MIT", + "engines": { + "node": ">=20.18.1" + } + }, "node_modules/undici-types": { "version": "5.26.5", "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", diff --git a/package.json b/package.json index 70588cfdc..5715aa871 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.5.14", + "version": "0.5.15", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", @@ -106,6 +106,7 @@ "svelte-sonner": "^0.3.19", "tippy.js": "^6.3.7", "turndown": "^7.2.0", + "undici": "^7.3.0", "uuid": "^9.0.1", "vite-plugin-static-copy": "^2.2.0" }, diff --git a/pyproject.toml b/pyproject.toml index dac8bbf78..d1175e605 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ license = { file = "LICENSE" } dependencies = [ "fastapi==0.115.7", "uvicorn[standard]==0.30.6", - "pydantic==2.9.2", + "pydantic==2.10.6", "python-multipart==0.0.18", "python-socketio==5.11.3", @@ -53,6 +53,7 @@ dependencies = [ "pymilvus==2.5.0", "qdrant-client~=1.12.0", "opensearch-py==2.8.0", + "playwright==1.49.1", "transformers", "sentence-transformers==3.3.1", @@ -65,7 +66,7 @@ dependencies = [ "pymdown-extensions==10.14.2", "docx2txt==0.8", "python-pptx==1.0.0", - "unstructured==0.16.11", + "unstructured==0.16.17", "nltk==3.9.1", "Markdown==3.7", "pypandoc==1.13", @@ -108,7 +109,13 @@ dependencies = [ "googleapis-common-protos==1.63.2", "google-cloud-storage==2.19.0", + "azure-identity==1.20.0", + "azure-storage-blob==12.24.1", + "ldap3==2.9.1", + + "firecrawl-py==1.12.0", + "gcp-storage-emulator>=2024.8.3", ] readme = "README.md" diff --git a/run-compose.sh b/run-compose.sh index 21574e959..4fafedc6f 100755 --- a/run-compose.sh +++ b/run-compose.sh @@ -74,6 +74,7 @@ usage() { echo " --enable-api[port=PORT] Enable API and expose it on the specified port." 
echo "  --webui[port=PORT]      Set the port for the web user interface." echo "  --data[folder=PATH]     Bind mount for ollama data folder (by default will create the 'ollama' volume)." + echo "  --playwright            Enable Playwright support for web scraping." echo "  --build                 Build the docker image before running the compose project." echo "  --drop                  Drop the compose project." echo "  -q, --quiet             Run script in headless mode." @@ -100,6 +101,7 @@ webui_port=3000 headless=false build_image=false kill_compose=false +enable_playwright=false # Function to extract value from the parameter extract_value() { @@ -129,6 +131,9 @@ while [[ $# -gt 0 ]]; do value=$(extract_value "$key") data_dir=${value:-"./ollama-data"} ;; + --playwright) + enable_playwright=true + ;; --drop) kill_compose=true ;; @@ -182,6 +187,9 @@ else DEFAULT_COMPOSE_COMMAND+=" -f docker-compose.data.yaml" export OLLAMA_DATA_DIR=$data_dir # Set OLLAMA_DATA_DIR environment variable fi + if [[ $enable_playwright == true ]]; then + DEFAULT_COMPOSE_COMMAND+=" -f docker-compose.playwright.yaml" + fi if [[ -n $webui_port ]]; then export OPEN_WEBUI_PORT=$webui_port # Set OPEN_WEBUI_PORT environment variable fi @@ -201,6 +209,7 @@ echo -e "   ${GREEN}${BOLD}GPU Count:${NC} ${OLLAMA_GPU_COUNT:-Not Enabled}" echo -e "   ${GREEN}${BOLD}WebAPI Port:${NC} ${OLLAMA_WEBAPI_PORT:-Not Enabled}" echo -e "   ${GREEN}${BOLD}Data Folder:${NC} ${data_dir:-Using ollama volume}" echo -e "   ${GREEN}${BOLD}WebUI Port:${NC} $webui_port" +echo -e "   ${GREEN}${BOLD}Playwright:${NC} ${enable_playwright:-false}" echo if [[ $headless == true ]]; then diff --git a/scripts/prepare-pyodide.js b/scripts/prepare-pyodide.js index 71f2a2cb2..70f3cf5c6 100644 --- a/scripts/prepare-pyodide.js +++ b/scripts/prepare-pyodide.js @@ -16,8 +16,39 @@ const packages = [ ]; import { loadPyodide } from 'pyodide'; +import { setGlobalDispatcher, ProxyAgent } from 'undici'; import { writeFile, readFile, copyFile, readdir, rmdir } from 'fs/promises'; +/** + * Load the network proxy configuration from the environment variables. + * A lowercase proxy variable takes priority over its uppercase counterpart. 
+ */ +function initNetworkProxyFromEnv() { + // we assume all subsequent requests in this script are HTTPS: + // https://cdn.jsdelivr.net + // https://pypi.org + // https://files.pythonhosted.org + const allProxy = process.env.all_proxy || process.env.ALL_PROXY; + const httpsProxy = process.env.https_proxy || process.env.HTTPS_PROXY; + const httpProxy = process.env.http_proxy || process.env.HTTP_PROXY; + const preferedProxy = httpsProxy || allProxy || httpProxy; + /** + * use only http(s) proxy because socks5 proxy is not supported currently: + * @see https://github.com/nodejs/undici/issues/2224 + */ + if (!preferedProxy || !preferedProxy.startsWith('http')) return; + let preferedProxyURL; + try { + preferedProxyURL = new URL(preferedProxy).toString(); + } catch { + console.warn(`Invalid network proxy URL: "${preferedProxy}"`); + return; + } + const dispatcher = new ProxyAgent({ uri: preferedProxyURL }); + setGlobalDispatcher(dispatcher); + console.log(`Initialized network proxy "${preferedProxy}" from env`); +} + async function downloadPackages() { console.log('Setting up pyodide + micropip'); @@ -84,5 +115,6 @@ async function copyPyodide() { } } +initNetworkProxyFromEnv(); await downloadPackages(); await copyPyodide(); diff --git a/src/app.css b/src/app.css index d324175b5..8bdc6f1ad 100644 --- a/src/app.css +++ b/src/app.css @@ -101,7 +101,7 @@ li p { /* Dark theme scrollbar styles */ .dark ::-webkit-scrollbar-thumb { - background-color: rgba(33, 33, 33, 0.8); /* Darker color for dark theme */ + background-color: rgba(42, 42, 42, 0.8); /* Darker color for dark theme */ border-color: rgba(0, 0, 0, var(--tw-border-opacity)); } diff --git a/src/lib/components/admin/Settings/CodeExecution.svelte b/src/lib/components/admin/Settings/CodeExecution.svelte index cbce8e0e9..c83537455 100644 --- a/src/lib/components/admin/Settings/CodeExecution.svelte +++ b/src/lib/components/admin/Settings/CodeExecution.svelte @@ -91,45 +91,65 @@ -
-
- {$i18n.t('Jupyter Auth')} -
+
+
+
+ {$i18n.t('Jupyter Auth')} +
-
- -
-
- - {#if config.CODE_EXECUTION_JUPYTER_AUTH} -
-
- {#if config.CODE_EXECUTION_JUPYTER_AUTH === 'password'} - - {:else} - - {/if} +
+
- {/if} + + {#if config.CODE_EXECUTION_JUPYTER_AUTH} +
+
+ {#if config.CODE_EXECUTION_JUPYTER_AUTH === 'password'} + + {:else} + + {/if} +
+
+ {/if} +
+ +
+
+ {$i18n.t('Code Execution Timeout')} +
+ +
+ + + +
+
{/if}
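
A minimal sketch of the settings payload this panel plausibly submits once the timeout field is filled in; apart from CODE_EXECUTION_JUPYTER_AUTH, which is visible in the hunk above, the key names here are assumptions rather than bindings confirmed by this diff:

    # Hypothetical admin-settings payload for the code execution panel.
    # Only CODE_EXECUTION_JUPYTER_AUTH appears verbatim in the hunk; the
    # other keys (including the timeout) are assumed for illustration.
    code_execution_config = {
        "CODE_EXECUTION_ENGINE": "jupyter",
        "CODE_EXECUTION_JUPYTER_URL": "http://localhost:8888",
        "CODE_EXECUTION_JUPYTER_AUTH": "token",  # or "password"
        "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": "<token>",
        "CODE_EXECUTION_JUPYTER_TIMEOUT": 60,  # seconds
    }
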
@@ -197,45 +217,65 @@
-
-
- {$i18n.t('Jupyter Auth')} -
+
+
+
+ {$i18n.t('Jupyter Auth')} +
-
- -
-
- - {#if config.CODE_INTERPRETER_JUPYTER_AUTH} -
-
- {#if config.CODE_INTERPRETER_JUPYTER_AUTH === 'password'} - - {:else} - - {/if} +
+
- {/if} + + {#if config.CODE_INTERPRETER_JUPYTER_AUTH} +
+
+ {#if config.CODE_INTERPRETER_JUPYTER_AUTH === 'password'} + + {:else} + + {/if} +
+
+ {/if} +
+ +
+
+ {$i18n.t('Code Execution Timeout')} +
+ +
+ + + +
+
{/if}
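
The interpreter-side timeout added above is consumed in the utils/middleware.py hunk earlier in this diff, which appends CODE_INTERPRETER_JUPYTER_TIMEOUT as a positional argument to the Jupyter execution call. A runnable sketch of one way such a timeout can be enforced; the actual helper's internals are not shown in this diff:

    import asyncio

    async def execute_with_timeout(coro, timeout: int = 60):
        # Bound the kernel call by the configured number of seconds.
        try:
            return await asyncio.wait_for(coro, timeout=timeout)
        except asyncio.TimeoutError:
            return {"stdout": "", "stderr": f"execution timed out after {timeout}s"}

    async def fake_kernel_call():
        # Stand-in for a real Jupyter execution request.
        await asyncio.sleep(0.01)
        return {"stdout": "ok", "stderr": ""}

    print(asyncio.run(execute_with_timeout(fake_kernel_call())))
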
diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index 917e924ae..c7c1f0e8f 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -27,7 +27,6 @@ import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte'; import Switch from '$lib/components/common/Switch.svelte'; - import { text } from '@sveltejs/kit'; import Textarea from '$lib/components/common/Textarea.svelte'; const i18n = getContext('i18n'); @@ -56,6 +55,8 @@ let chunkOverlap = 0; let pdfExtractImages = true; + let RAG_FULL_CONTEXT = false; + let enableGoogleDriveIntegration = false; let OpenAIUrl = ''; @@ -182,6 +183,7 @@ max_size: fileMaxSize === '' ? null : fileMaxSize, max_count: fileMaxCount === '' ? null : fileMaxCount }, + RAG_FULL_CONTEXT: RAG_FULL_CONTEXT, chunk: { text_splitter: textSplitter, chunk_overlap: chunkOverlap, @@ -242,6 +244,8 @@ chunkSize = res.chunk.chunk_size; chunkOverlap = res.chunk.chunk_overlap; + RAG_FULL_CONTEXT = res.RAG_FULL_CONTEXT; + contentExtractionEngine = res.content_extraction.engine; tikaServerUrl = res.content_extraction.tika_server_url; showTikaServerUrl = contentExtractionEngine === 'tika'; @@ -388,6 +392,19 @@ {/if}
+ +
+
{$i18n.t('Full Context Mode')}
+
+ + + +
+
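
The Full Context Mode toggle above maps to the RAG_FULL_CONTEXT flag that the utils/middleware.py hunk earlier in this diff passes into retrieval as full_context=. A rough sketch of the behavioral difference, with hypothetical document structures; the real retrieval pipeline is not shown here:

    # Illustrative only: with full_context set, whole documents are injected
    # instead of the top-k chunks a vector search would return.
    def build_context(docs: list[dict], query: str, full_context: bool, k: int = 3) -> str:
        if full_context:
            # Inject every document in its entirety.
            return "\n\n".join(d["content"] for d in docs)
        # Default mode: a stubbed stand-in for top-k chunk retrieval.
        chunks = [c for d in docs for c in d["content"].split("\n\n")]
        ranked = sorted(
            chunks,
            key=lambda c: -sum(w in c.lower() for w in query.lower().split()),
        )
        return "\n\n".join(ranked[:k])
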

diff --git a/src/lib/components/admin/Settings/Images.svelte b/src/lib/components/admin/Settings/Images.svelte index 957cc6971..e63158bcd 100644 --- a/src/lib/components/admin/Settings/Images.svelte +++ b/src/lib/components/admin/Settings/Images.svelte @@ -261,6 +261,9 @@ } else if (config.engine === 'openai' && config.openai.OPENAI_API_KEY === '') { toast.error($i18n.t('OpenAI API Key is required.')); config.enabled = false; + } else if (config.engine === 'gemini' && config.gemini.GEMINI_API_KEY === '') { + toast.error($i18n.t('Gemini API Key is required.')); + config.enabled = false; } } @@ -294,6 +297,7 @@ + @@ -605,6 +609,24 @@ /> + {:else if config?.engine === 'gemini'} +
+
{$i18n.t('Gemini API Config')}
+ +
+ + + +
+
{/if} diff --git a/src/lib/components/admin/Settings/Interface.svelte b/src/lib/components/admin/Settings/Interface.svelte index 8b2a310ef..e3542475e 100644 --- a/src/lib/components/admin/Settings/Interface.svelte +++ b/src/lib/components/admin/Settings/Interface.svelte @@ -51,7 +51,7 @@ onMount(async () => { taskConfig = await getTaskConfig(localStorage.token); - promptSuggestions = $config?.default_prompt_suggestions; + promptSuggestions = $config?.default_prompt_suggestions ?? []; banners = await getBanners(localStorage.token); }); diff --git a/src/lib/components/admin/Users/UserList.svelte b/src/lib/components/admin/Users/UserList.svelte index 3f7832517..7f8e516fb 100644 --- a/src/lib/components/admin/Users/UserList.svelte +++ b/src/lib/components/admin/Users/UserList.svelte @@ -85,8 +85,9 @@ return true; } else { let name = user.name.toLowerCase(); + let email = user.email.toLowerCase(); const query = search.toLowerCase(); - return name.includes(query); + return name.includes(query) || email.includes(query); } }) .sort((a, b) => { diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index d18c7d4d2..5cde963ee 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -430,7 +430,7 @@ {/if} - {#if webSearchEnabled || ($settings?.webSearch ?? false) === 'always'} + {#if webSearchEnabled || ($config?.features?.enable_web_search && ($settings?.webSearch ?? false)) === 'always'}
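
Earlier in this diff, chat_web_search_handler stops searching only queries[0] and instead loops over every generated query, tolerating per-query failures and merging the resulting URLs for the final status event. The aggregation pattern, reduced to a runnable sketch with a stand-in search function:

    def search_all(queries: list[str], search_fn) -> tuple[list[dict], list[str]]:
        # Mirrors the new loop: one failure no longer aborts the remaining queries.
        all_results = []
        for query in queries:
            try:
                results = search_fn(query)
                if results:
                    all_results.append(results)
            except Exception:
                continue
        # Merge URLs across queries, as done for the "Searched N sites" status.
        urls = [u for r in all_results for u in r.get("filenames", [])]
        return all_results, urls

    print(search_all(["a", "b"], lambda q: {"filenames": [f"https://example.com/{q}"]}))
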
diff --git a/src/lib/components/chat/Messages/Citations.svelte b/src/lib/components/chat/Messages/Citations.svelte index f8d57cbb7..095e29edf 100644 --- a/src/lib/components/chat/Messages/Citations.svelte +++ b/src/lib/components/chat/Messages/Citations.svelte @@ -7,6 +7,7 @@ const i18n = getContext('i18n'); + export let id = ''; export let sources = []; let citations = []; @@ -100,7 +101,7 @@
{#each citations as citation, idx}
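
The rewritten apply_model_params_to_body_ollama and convert_payload_openai_to_ollama above both rename OpenAI-style keys before Ollama sees them. A worked example of the rename step in isolation:

    # OpenAI-style keys are moved to their Ollama names so Ollama does not
    # receive an unknown "max_tokens" option; values are left untouched.
    params = {"max_tokens": 256, "temperature": 0.7}
    name_differences = {"max_tokens": "num_predict"}
    for key, new_key in name_differences.items():
        if params.get(key) is not None:
            params[new_key] = params.pop(key)
    print(params)  # {'temperature': 0.7, 'num_predict': 256}
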