diff --git a/CHANGELOG.md b/CHANGELOG.md index bad83dc1e..86ff57384 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,43 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.5.12] - 2025-02-13 + +### Added + +- **🛠️ Multiple Tool Calls Support for Native Function Mode**: Functions can now call multiple tools within a single response, unlocking better automation and workflow flexibility when using native function calling. + +### Fixed + +- **📝 Playground Text Completion Restored**: Addressed an issue where text completion in the Playground was not functioning. +- **🔗 Direct Connections Now Work for Regular Users**: Fixed a bug where users with the 'user' role couldn't establish direct API connections, enabling seamless model usage for all user tiers. +- **⚡ Landing Page Input No Longer Lags with Long Text**: Improved input responsiveness on the landing page, ensuring a fast and smooth typing experience even when entering long messages. +- **🔧 Parameter in Functions Fixed**: Fixed an issue where reserved parameters weren't recognized within functions, restoring full functionality for advanced task-based automation. + +## [0.5.11] - 2025-02-13 + +### Added + +- **🎤 Kokoro-JS TTS Support**: A new on-device, high-quality text-to-speech engine has been integrated, vastly improving voice generation quality; everything runs directly in your browser. +- **🐍 Jupyter Notebook Support in Code Interpreter**: You can now configure Code Interpreter to run Python code not only via Pyodide but also through Jupyter, offering a more robust coding environment for AI-driven computations and analysis. +- **🔗 Direct API Connections for Private & Local Inference**: You can now connect Open WebUI to your private or localhost API inference endpoints. CORS must be enabled, but this unlocks direct, on-device AI infrastructure support. +- **🔍 Advanced Domain Filtering for Web Search**: You can now specify which domains should be included or excluded from web searches, refining results for more relevant information retrieval. +- **🚀 Improved Image Generation Metadata Handling**: Generated images now retain metadata for better organization and future retrieval. +- **📂 S3 Key Prefix Support**: Fine-grained control over S3 storage file structuring with configurable key prefixes. +- **📸 Support for Image-Only Messages**: Send messages containing only images, facilitating more visual-centric interactions. +- **🌍 Updated Translations**: German, Spanish, Traditional Chinese, and Catalan translations updated for better multilingual support. + +### Fixed + +- **🔧 OAuth Debug Logs & Username Claim Fixes**: Debug logs have been added for OAuth role and group management, with fixes ensuring proper OAuth username retrieval and claim handling. +- **📌 Citations Formatting & Toggle Fixes**: Inline citation toggles now function correctly, and citations with more than three sources are now fully visible when expanded. +- **📸 ComfyUI Maximum Seed Value Constraint Fixed**: The maximum allowed seed value for ComfyUI has been corrected, preventing unintended behavior. +- **🔑 Connection Settings Stability**: Addressed connection settings issues that were causing instability when saving configurations. +- **📂 GGUF Model Upload Stability**: Fixed upload inconsistencies for GGUF models, ensuring reliable local model handling.
+- **🔧 Web Search Configuration Bug**: Fixed issues where web search filters and settings weren't correctly applied. +- **💾 User Settings Persistence Fix**: Ensured user-specific settings are correctly saved and applied across sessions. +- **🔄 OpenID Username Retrieval Enhancement**: Usernames are now correctly picked up and assigned for OpenID Connect (OIDC) logins. + ## [0.5.10] - 2025-02-05 ### Fixed diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index d3e323a4a..c926759ca 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1190,6 +1190,12 @@ ENABLE_TAGS_GENERATION = PersistentConfig( os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true", ) +ENABLE_TITLE_GENERATION = PersistentConfig( + "ENABLE_TITLE_GENERATION", + "task.title.enable", + os.environ.get("ENABLE_TITLE_GENERATION", "True").lower() == "true", +) + ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig( "ENABLE_SEARCH_QUERY_GENERATION", @@ -1803,6 +1809,18 @@ SEARCHAPI_ENGINE = PersistentConfig( os.getenv("SEARCHAPI_ENGINE", ""), ) +SERPAPI_API_KEY = PersistentConfig( + "SERPAPI_API_KEY", + "rag.web.search.serpapi_api_key", + os.getenv("SERPAPI_API_KEY", ""), +) + +SERPAPI_ENGINE = PersistentConfig( + "SERPAPI_ENGINE", + "rag.web.search.serpapi_engine", + os.getenv("SERPAPI_ENGINE", ""), +) + BING_SEARCH_V7_ENDPOINT = PersistentConfig( "BING_SEARCH_V7_ENDPOINT", "rag.web.search.bing_search_v7_endpoint", @@ -1841,6 +1859,12 @@ RAG_WEB_LOADER = PersistentConfig( os.environ.get("RAG_WEB_LOADER", "safe_web") ) +RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig( + "RAG_WEB_SEARCH_TRUST_ENV", + "rag.web.search.trust_env", + os.environ.get("RAG_WEB_SEARCH_TRUST_ENV", "False").lower() == "true", +) + PLAYWRIGHT_WS_URI = PersistentConfig( "PLAYWRIGHT_WS_URI", "rag.web.loader.playwright.ws.uri", diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index d6e2832f1..5035119a0 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -177,10 +177,13 @@ from open_webui.config import ( RAG_WEB_SEARCH_ENGINE, RAG_WEB_SEARCH_RESULT_COUNT, RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + RAG_WEB_SEARCH_TRUST_ENV, RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, JINA_API_KEY, SEARCHAPI_API_KEY, SEARCHAPI_ENGINE, + SERPAPI_API_KEY, + SERPAPI_ENGINE, SEARXNG_QUERY_URL, SERPER_API_KEY, SERPLY_API_KEY, @@ -266,6 +269,7 @@ from open_webui.config import ( TASK_MODEL, TASK_MODEL_EXTERNAL, ENABLE_TAGS_GENERATION, + ENABLE_TITLE_GENERATION, ENABLE_SEARCH_QUERY_GENERATION, ENABLE_RETRIEVAL_QUERY_GENERATION, ENABLE_AUTOCOMPLETE_GENERATION, @@ -347,12 +351,12 @@ class SPAStaticFiles(StaticFiles): print( rf""" - ___ __ __ _ _ _ ___ - / _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _| -| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || | -| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || | - \___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___| - |_| + ██████╗ ██████╗ ███████╗███╗ ██╗ ██╗ ██╗███████╗██████╗ ██╗ ██╗██╗ +██╔═══██╗██╔══██╗██╔════╝████╗ ██║ ██║ ██║██╔════╝██╔══██╗██║ ██║██║ +██║ ██║██████╔╝█████╗ ██╔██╗ ██║ ██║ █╗ ██║█████╗ ██████╔╝██║ ██║██║ +██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║ ██║███╗██║██╔══╝ ██╔══██╗██║ ██║██║ +╚██████╔╝██║ ███████╗██║ ╚████║ ╚███╔███╔╝███████╗██████╔╝╚██████╔╝██║ + ╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝ ╚══╝╚══╝ ╚══════╝╚═════╝ ╚═════╝ ╚═╝ v{VERSION} - building the best open-source AI user interface.
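# A minimal sketch (not part of the patch) of what the new RAG_WEB_SEARCH_TRUST_ENV
# flag changes, assuming only documented aiohttp behavior: with trust_env=True,
# aiohttp.ClientSession honors the http_proxy/https_proxy/no_proxy environment
# variables, which is how SafeWebBaseLoader (further below) can route web-search
# fetches through a proxy.
import asyncio

import aiohttp


async def fetch(url: str, trust_env: bool) -> str:
    # trust_env=True reads proxy settings from the environment; False ignores them.
    async with aiohttp.ClientSession(trust_env=trust_env) as session:
        async with session.get(url) as response:
            return await response.text()


# Example usage: asyncio.run(fetch("https://example.com/", trust_env=True))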
@@ -548,6 +552,8 @@ app.state.config.SERPLY_API_KEY = SERPLY_API_KEY app.state.config.TAVILY_API_KEY = TAVILY_API_KEY app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE +app.state.config.SERPAPI_API_KEY = SERPAPI_API_KEY +app.state.config.SERPAPI_ENGINE = SERPAPI_ENGINE app.state.config.JINA_API_KEY = JINA_API_KEY app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY @@ -556,6 +562,7 @@ app.state.config.EXA_API_KEY = EXA_API_KEY app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS app.state.config.RAG_WEB_LOADER = RAG_WEB_LOADER +app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI app.state.EMBEDDING_FUNCTION = None @@ -693,6 +700,7 @@ app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION +app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE @@ -904,20 +912,30 @@ async def chat_completion( if not request.app.state.MODELS: await get_all_models(request) + model_item = form_data.pop("model_item", {}) tasks = form_data.pop("background_tasks", None) - try: - model_id = form_data.get("model", None) - if model_id not in request.app.state.MODELS: - raise Exception("Model not found") - model = request.app.state.MODELS[model_id] - model_info = Models.get_model_by_id(model_id) - # Check if user has access to the model - if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user": - try: - check_model_access(user, model) - except Exception as e: - raise e + try: + if not model_item.get("direct", False): + model_id = form_data.get("model", None) + if model_id not in request.app.state.MODELS: + raise Exception("Model not found") + + model = request.app.state.MODELS[model_id] + model_info = Models.get_model_by_id(model_id) + + # Check if user has access to the model + if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user": + try: + check_model_access(user, model) + except Exception as e: + raise e + else: + model = model_item + model_info = None + + request.state.direct = True + request.state.model = model metadata = { "user_id": user.id, @@ -929,6 +947,7 @@ async def chat_completion( "features": form_data.get("features", None), "variables": form_data.get("variables", None), "model": model_info, + "direct": model_item.get("direct", False), **( {"function_calling": "native"} if form_data.get("params", {}).get("function_calling") == "native" @@ -940,6 +959,8 @@ async def chat_completion( else {} ), } + + request.state.metadata = metadata form_data["metadata"] = metadata form_data, metadata, events = await process_chat_payload( @@ -947,6 +968,7 @@ async def chat_completion( ) except Exception as e: + log.debug(f"Error processing chat payload: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=str(e), @@ -975,6 +997,12 @@ async def chat_completed( request: Request, form_data: dict, user=Depends(get_verified_user) ): try: + model_item = form_data.pop("model_item", {}) + + if model_item.get("direct", 
False): + request.state.direct = True + request.state.model = model_item + return await chat_completed_handler(request, form_data, user) except Exception as e: raise HTTPException( @@ -988,6 +1016,12 @@ async def chat_action( request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user) ): try: + model_item = form_data.pop("model_item", {}) + + if model_item.get("direct", False): + request.state.direct = True + request.state.model = model_item + return await chat_action_handler(request, action_id, form_data, user) except Exception as e: raise HTTPException( diff --git a/backend/open_webui/retrieval/web/bocha.py b/backend/open_webui/retrieval/web/bocha.py index 98bdae704..f26da36f8 100644 --- a/backend/open_webui/retrieval/web/bocha.py +++ b/backend/open_webui/retrieval/web/bocha.py @@ -9,6 +9,7 @@ from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) + def _parse_response(response): result = {} if "data" in response: @@ -25,7 +26,8 @@ def _parse_response(response): "summary": item.get("summary", ""), "siteName": item.get("siteName", ""), "siteIcon": item.get("siteIcon", ""), - "datePublished": item.get("datePublished", "") or item.get("dateLastCrawled", ""), + "datePublished": item.get("datePublished", "") + or item.get("dateLastCrawled", ""), } for item in webPages["value"] ] @@ -42,17 +44,11 @@ def search_bocha( query (str): The query to search for """ url = "https://api.bochaai.com/v1/web-search?utm_source=ollama" - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json" - } - - payload = json.dumps({ - "query": query, - "summary": True, - "freshness": "noLimit", - "count": count - }) + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + + payload = json.dumps( + {"query": query, "summary": True, "freshness": "noLimit", "count": count} + ) response = requests.post(url, headers=headers, data=payload, timeout=5) response.raise_for_status() @@ -63,10 +59,7 @@ def search_bocha( return [ SearchResult( - link=result["url"], - title=result.get("name"), - snippet=result.get("summary") + link=result["url"], title=result.get("name"), snippet=result.get("summary") ) - for result in results.get("webpage", [])[:count] + for result in results.get("webpage", [])[:count] ] - diff --git a/backend/open_webui/retrieval/web/google_pse.py b/backend/open_webui/retrieval/web/google_pse.py index ea8225262..2d2b863b4 100644 --- a/backend/open_webui/retrieval/web/google_pse.py +++ b/backend/open_webui/retrieval/web/google_pse.py @@ -8,6 +8,7 @@ from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) + def search_google_pse( api_key: str, search_engine_id: str, @@ -46,12 +47,14 @@ def search_google_pse( response.raise_for_status() json_response = response.json() results = json_response.get("items", []) - if results: # check if results are returned. If not, no more pages to fetch. + if results: # check if results are returned. If not, no more pages to fetch. all_results.extend(results) - count -= len(results) # Decrement count by the number of results fetched in this page. - start_index += 10 # Increment start index for the next page + count -= len( + results + ) # Decrement count by the number of results fetched in this page. 
+ start_index += 10  # Increment start index for the next page else: - break  # No more results from Google PSE, break the loop + break  # No more results from Google PSE, break the loop if filter_list: all_results = get_filtered_results(all_results, filter_list) diff --git a/backend/open_webui/retrieval/web/serpapi.py b/backend/open_webui/retrieval/web/serpapi.py new file mode 100644 index 000000000..028b6bcfe --- /dev/null +++ b/backend/open_webui/retrieval/web/serpapi.py @@ -0,0 +1,50 @@ +import logging +from typing import Optional +from urllib.parse import urlencode + +import requests +from open_webui.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_serpapi( + api_key: str, + engine: str, + query: str, + count: int, + filter_list: Optional[list[str]] = None, +) -> list[SearchResult]: + """Search using serpapi.com's API and return the results as a list of SearchResult objects. + + Args: + api_key (str): A serpapi.com API key + engine (str): The serpapi.com engine to use (falls back to "google" when empty) + query (str): The query to search for + """ + url = "https://serpapi.com/search" + + engine = engine or "google" + + payload = {"engine": engine, "q": query, "api_key": api_key} + + url = f"{url}?{urlencode(payload)}" + response = requests.get(url) + response.raise_for_status() + + json_response = response.json() + log.info(f"results from serpapi search: {json_response}") + + results = sorted( + json_response.get("organic_results", []), key=lambda x: x.get("position", 0) + ) + if filter_list: + results = get_filtered_results(results, filter_list) + return [ + SearchResult( + link=result["link"], title=result["title"], snippet=result["snippet"] + ) + for result in results[:count] + ] diff --git a/backend/open_webui/retrieval/web/tavily.py b/backend/open_webui/retrieval/web/tavily.py index cc468725d..7839b715e 100644 --- a/backend/open_webui/retrieval/web/tavily.py +++ b/backend/open_webui/retrieval/web/tavily.py @@ -1,4 +1,5 @@ import logging +from typing import Optional import requests from open_webui.retrieval.web.main import SearchResult @@ -8,7 +9,12 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]: +def search_tavily( + api_key: str, + query: str, + count: int, + filter_list: Optional[list[str]] = None, +) -> list[SearchResult]: """Search using Tavily's Search API and return the results as a list of SearchResult objects.
Args: @@ -20,8 +26,11 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]: """ url = "https://api.tavily.com/search" data = {"query": query, "api_key": api_key} - - response = requests.post(url, json=data) + if filter_list: + # Tavily's search API accepts an "include_domains" list for domain filtering + data["include_domains"] = filter_list + + response = requests.post(url, json=data) response.raise_for_status() json_response = response.json() diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index 3c33809d7..25ca5aef5 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -2,11 +2,13 @@ import asyncio from datetime import datetime, time, timedelta import socket import ssl +import aiohttp import urllib.parse import certifi import validators from collections import defaultdict from typing import AsyncIterator, Dict, List, Optional, Union, Sequence, Iterator +from typing import Any from langchain_community.document_loaders import ( WebBaseLoader, @@ -230,6 +232,71 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader): class SafeWebBaseLoader(WebBaseLoader): """WebBaseLoader with enhanced error handling for URLs.""" + def __init__(self, trust_env: bool = False, *args, **kwargs): + """Initialize SafeWebBaseLoader + Args: + trust_env (bool, optional): set to True if using proxy to make web requests, for example + using http(s)_proxy environment variables. Defaults to False. + """ + super().__init__(*args, **kwargs) + self.trust_env = trust_env + + async def _fetch( + self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5 + ) -> str: + async with aiohttp.ClientSession(trust_env=self.trust_env) as session: + for i in range(retries): + try: + kwargs: Dict = dict( + headers=self.session.headers, + cookies=self.session.cookies.get_dict(), + ) + if not self.session.verify: + kwargs["ssl"] = False + + async with session.get( + url, **(self.requests_kwargs | kwargs) + ) as response: + if self.raise_for_status: + response.raise_for_status() + return await response.text() + except aiohttp.ClientConnectionError as e: + if i == retries - 1: + raise + else: + log.warning( + f"Error fetching {url} with attempt " + f"{i + 1}/{retries}: {e}. Retrying..."
+ ) + await asyncio.sleep(cooldown * backoff**i) + raise ValueError("retry count exceeded") + + def _unpack_fetch_results( + self, results: Any, urls: List[str], parser: Union[str, None] = None + ) -> List[Any]: + """Unpack fetch results into BeautifulSoup objects.""" + from bs4 import BeautifulSoup + + final_results = [] + for i, result in enumerate(results): + url = urls[i] + if parser is None: + if url.endswith(".xml"): + parser = "xml" + else: + parser = self.default_parser + self._check_parser(parser) + final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs)) + return final_results + + async def ascrape_all( + self, urls: List[str], parser: Union[str, None] = None + ) -> List[Any]: + """Async fetch all urls, then return soups for all results.""" + results = await self.fetch_all(urls) + return self._unpack_fetch_results(results, urls, parser=parser) + + def lazy_load(self) -> Iterator[Document]: """Lazy load text from the url(s) in web_path with error handling.""" for path in self.web_paths: @@ -245,6 +312,26 @@ class SafeWebBaseLoader(WebBaseLoader): # Log the error and continue with the next URL log.exception(e, "Error loading %s", path) + async def alazy_load(self) -> AsyncIterator[Document]: + """Async lazy load text from the url(s) in web_path.""" + results = await self.ascrape_all(self.web_paths) + for path, soup in zip(self.web_paths, results): + text = soup.get_text(**self.bs_get_text_kwargs) + metadata = {"source": path} + if title := soup.find("title"): + metadata["title"] = title.get_text() + if description := soup.find("meta", attrs={"name": "description"}): + metadata["description"] = description.get( + "content", "No description found." + ) + if html := soup.find("html"): + metadata["language"] = html.get("lang", "No language found.") + yield Document(page_content=text, metadata=metadata) + + async def aload(self) -> list[Document]: + """Load data into Document objects.""" + return [document async for document in self.alazy_load()] + RAG_WEB_LOADERS = defaultdict(lambda: SafeWebBaseLoader) RAG_WEB_LOADERS["playwright"] = SafePlaywrightURLLoader RAG_WEB_LOADERS["safe_web"] = SafeWebBaseLoader @@ -253,16 +340,19 @@ def get_web_loader( urls: Union[str, Sequence[str]], verify_ssl: bool = True, requests_per_second: int = 2, + trust_env: bool = False, ): # Check if the URLs are valid safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls) web_loader_args = { + "web_path": safe_urls, "urls": safe_urls, "verify_ssl": verify_ssl, "requests_per_second": requests_per_second, - "continue_on_failure": True + "continue_on_failure": True, + "trust_env": trust_env, } if PLAYWRIGHT_WS_URI.value: diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index a624b8048..e79e414b1 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -21,6 +21,7 @@ from fastapi import ( APIRouter, ) from fastapi.middleware.cors import CORSMiddleware +from fastapi.concurrency import run_in_threadpool from pydantic import BaseModel import tiktoken @@ -50,6 +51,7 @@ from open_webui.retrieval.web.duckduckgo import search_duckduckgo from open_webui.retrieval.web.google_pse import search_google_pse from open_webui.retrieval.web.jina_search import search_jina from open_webui.retrieval.web.searchapi import search_searchapi +from open_webui.retrieval.web.serpapi import search_serpapi from open_webui.retrieval.web.searxng import search_searxng from open_webui.retrieval.web.serper import search_serper from 
open_webui.retrieval.web.serply import search_serply @@ -388,6 +390,8 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "tavily_api_key": request.app.state.config.TAVILY_API_KEY, "searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY, "searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE, + "serpapi_api_key": request.app.state.config.SERPAPI_API_KEY, + "serpapi_engine": request.app.state.config.SERPAPI_ENGINE, "jina_api_key": request.app.state.config.JINA_API_KEY, "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT, "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, @@ -439,12 +443,15 @@ class WebSearchConfig(BaseModel): tavily_api_key: Optional[str] = None searchapi_api_key: Optional[str] = None searchapi_engine: Optional[str] = None + serpapi_api_key: Optional[str] = None + serpapi_engine: Optional[str] = None jina_api_key: Optional[str] = None bing_search_v7_endpoint: Optional[str] = None bing_search_v7_subscription_key: Optional[str] = None exa_api_key: Optional[str] = None result_count: Optional[int] = None concurrent_requests: Optional[int] = None + trust_env: Optional[bool] = None domain_filter_list: Optional[List[str]] = [] @@ -545,6 +552,9 @@ async def update_rag_config( form_data.web.search.searchapi_engine ) + request.app.state.config.SERPAPI_API_KEY = form_data.web.search.serpapi_api_key + request.app.state.config.SERPAPI_ENGINE = form_data.web.search.serpapi_engine + request.app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key request.app.state.config.BING_SEARCH_V7_ENDPOINT = ( form_data.web.search.bing_search_v7_endpoint @@ -561,6 +571,9 @@ async def update_rag_config( request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( form_data.web.search.concurrent_requests ) + request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV = ( + form_data.web.search.trust_env + ) request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = ( form_data.web.search.domain_filter_list ) @@ -604,6 +617,8 @@ async def update_rag_config( "serply_api_key": request.app.state.config.SERPLY_API_KEY, "serachapi_api_key": request.app.state.config.SEARCHAPI_API_KEY, "searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE, + "serpapi_api_key": request.app.state.config.SERPAPI_API_KEY, + "serpapi_engine": request.app.state.config.SERPAPI_ENGINE, "tavily_api_key": request.app.state.config.TAVILY_API_KEY, "jina_api_key": request.app.state.config.JINA_API_KEY, "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT, @@ -611,6 +626,7 @@ async def update_rag_config( "exa_api_key": request.app.state.config.EXA_API_KEY, "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + "trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV, "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, }, }, @@ -760,7 +776,11 @@ def save_docs_to_vector_db( # for meta-data so convert them to string. 
for metadata in metadatas: for key, value in metadata.items(): - if isinstance(value, datetime): + if ( + isinstance(value, datetime) + or isinstance(value, list) + or isinstance(value, dict) + ): metadata[key] = str(value) try: @@ -1127,6 +1147,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: - TAVILY_API_KEY - EXA_API_KEY - SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`) + - SERPAPI_API_KEY + SERPAPI_ENGINE (by default `google`) Args: query (str): The query to search for """ @@ -1255,6 +1276,17 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: ) else: raise Exception("No SEARCHAPI_API_KEY found in environment variables") + elif engine == "serpapi": + if request.app.state.config.SERPAPI_API_KEY: + return search_serpapi( + request.app.state.config.SERPAPI_API_KEY, + request.app.state.config.SERPAPI_ENGINE, + query, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No SERPAPI_API_KEY found in environment variables") elif engine == "jina": return search_jina( request.app.state.config.JINA_API_KEY, @@ -1314,17 +1346,23 @@ async def process_web_search( urls, verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV, ) - docs = [doc async for doc in loader.alazy_load()] - # docs = loader.load() - save_docs_to_vector_db( - request, docs, collection_name, overwrite=True, user=user + docs = await loader.aload() + await run_in_threadpool( + save_docs_to_vector_db, + request, + docs, + collection_name, + overwrite=True, + user=user ) return { "status": True, "collection_name": collection_name, "filenames": urls, + "loaded_count": len(docs), } except Exception as e: log.exception(e) diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index f56a0232d..8b17c6c4b 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -58,6 +58,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)): "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, + "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION, "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, @@ -68,6 +69,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)): class TaskConfigForm(BaseModel): TASK_MODEL: Optional[str] TASK_MODEL_EXTERNAL: Optional[str] + ENABLE_TITLE_GENERATION: bool TITLE_GENERATION_PROMPT_TEMPLATE: str IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str ENABLE_AUTOCOMPLETE_GENERATION: bool @@ -86,6 +88,7 @@ async def update_task_config( ): request.app.state.config.TASK_MODEL = form_data.TASK_MODEL request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL + request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = 
( form_data.TITLE_GENERATION_PROMPT_TEMPLATE ) @@ -122,6 +125,7 @@ async def update_task_config( return { "TASK_MODEL": request.app.state.config.TASK_MODEL, "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, + "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION, "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, @@ -139,7 +143,19 @@ async def update_task_config( async def generate_title( request: Request, form_data: dict, user=Depends(get_verified_user) ): - models = request.app.state.MODELS + + if not request.app.state.config.ENABLE_TITLE_GENERATION: + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"detail": "Title generation is disabled"}, + ) + + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -198,6 +214,7 @@ async def generate_title( } ), "metadata": { + **(request.state.metadata if hasattr(request.state, "metadata") else {}), "task": str(TASKS.TITLE_GENERATION), "task_body": form_data, "chat_id": form_data.get("chat_id", None), @@ -225,7 +242,12 @@ async def generate_chat_tags( content={"detail": "Tags generation is disabled"}, ) - models = request.app.state.MODELS + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -261,6 +283,7 @@ async def generate_chat_tags( "messages": [{"role": "user", "content": content}], "stream": False, "metadata": { + **(request.state.metadata if hasattr(request.state, "metadata") else {}), "task": str(TASKS.TAGS_GENERATION), "task_body": form_data, "chat_id": form_data.get("chat_id", None), @@ -281,7 +304,12 @@ async def generate_chat_tags( async def generate_image_prompt( request: Request, form_data: dict, user=Depends(get_verified_user) ): - models = request.app.state.MODELS + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -321,6 +349,7 @@ async def generate_image_prompt( "messages": [{"role": "user", "content": content}], "stream": False, "metadata": { + **(request.state.metadata if hasattr(request.state, "metadata") else {}), "task": str(TASKS.IMAGE_PROMPT_GENERATION), "task_body": form_data, "chat_id": form_data.get("chat_id", None), @@ -356,7 +385,12 @@ async def generate_queries( detail=f"Query generation is disabled", ) - models = request.app.state.MODELS + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -392,6 +426,7 @@ async def generate_queries( "messages": [{"role": "user", "content": content}], "stream": False, "metadata": { + **(request.state.metadata if hasattr(request.state, "metadata") else {}), "task": str(TASKS.QUERY_GENERATION), "task_body": 
form_data, "chat_id": form_data.get("chat_id", None), @@ -431,7 +466,12 @@ async def generate_autocompletion( detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", ) - models = request.app.state.MODELS + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -467,6 +507,7 @@ async def generate_autocompletion( "messages": [{"role": "user", "content": content}], "stream": False, "metadata": { + **(request.state.metadata if hasattr(request.state, "metadata") else {}), "task": str(TASKS.AUTOCOMPLETE_GENERATION), "task_body": form_data, "chat_id": form_data.get("chat_id", None), @@ -488,7 +529,12 @@ async def generate_emoji( request: Request, form_data: dict, user=Depends(get_verified_user) ): - models = request.app.state.MODELS + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -531,7 +577,11 @@ async def generate_emoji( } ), "chat_id": form_data.get("chat_id", None), - "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data}, + "metadata": { + **(request.state.metadata if hasattr(request.state, "metadata") else {}), + "task": str(TASKS.EMOJI_GENERATION), + "task_body": form_data, + }, } try: @@ -548,7 +598,13 @@ async def generate_moa_response( request: Request, form_data: dict, user=Depends(get_verified_user) ): - models = request.app.state.MODELS + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS + model_id = form_data["model"] if model_id not in models: @@ -581,6 +637,7 @@ async def generate_moa_response( "messages": [{"role": "user", "content": content}], "stream": form_data.get("stream", False), "metadata": { + **(request.state.metadata if hasattr(request.state, "metadata") else {}), "chat_id": form_data.get("chat_id", None), "task": str(TASKS.MOA_RESPONSE_GENERATION), "task_body": form_data, diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 3788139ea..6f5915122 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -279,8 +279,8 @@ def get_event_emitter(request_info): await sio.emit( "chat-events", { - "chat_id": request_info["chat_id"], - "message_id": request_info["message_id"], + "chat_id": request_info.get("chat_id", None), + "message_id": request_info.get("message_id", None), "data": event_data, }, to=session_id, @@ -329,8 +329,8 @@ def get_event_call(request_info): response = await sio.call( "chat-events", { - "chat_id": request_info["chat_id"], - "message_id": request_info["message_id"], + "chat_id": request_info.get("chat_id", None), + "message_id": request_info.get("message_id", None), "data": event_data, }, to=request_info["session_id"], diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 3b6d5ea04..569bcad85 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -7,14 +7,17 @@ from typing import Any, Optional import random import json import inspect +import uuid +import asyncio -from fastapi import Request -from 
starlette.responses import Response, StreamingResponse +from fastapi import Request, status +from starlette.responses import Response, StreamingResponse, JSONResponse from open_webui.models.users import UserModel from open_webui.socket.main import ( + sio, get_event_call, get_event_emitter, ) @@ -57,16 +60,126 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) +async def generate_direct_chat_completion( + request: Request, + form_data: dict, + user: Any, + models: dict, +): + log.debug("generate_direct_chat_completion") + + metadata = form_data.pop("metadata", {}) + + user_id = metadata.get("user_id") + session_id = metadata.get("session_id") + request_id = str(uuid.uuid4())  # Generate a unique request ID + + event_caller = get_event_call(metadata) + + channel = f"{user_id}:{session_id}:{request_id}" + + if form_data.get("stream"): + q = asyncio.Queue() + + async def message_listener(sid, data): + """ + Handle received socket messages and push them into the queue. + """ + await q.put(data) + + # Register the listener + sio.on(channel, message_listener) + + # Start processing chat completion in background + res = await event_caller( + { + "type": "request:chat:completion", + "data": { + "form_data": form_data, + "model": models[form_data["model"]], + "channel": channel, + "session_id": session_id, + }, + } + ) + + log.debug(f"res: {res}") + + if res.get("status", False): + # Define a generator to stream responses + async def event_generator(): + nonlocal q + try: + while True: + data = await q.get()  # Wait for new messages + if isinstance(data, dict): + if "done" in data and data["done"]: + break  # Stop streaming when 'done' is received + + yield f"data: {json.dumps(data)}\n\n" + elif isinstance(data, str): + yield data + except Exception as e: + log.debug(f"Error in event generator: {e}") + + # Define a background task to run the event generator + async def background(): + try: + del sio.handlers["/"][channel] + except Exception: + pass + + # Return the streaming response + return StreamingResponse( + event_generator(), media_type="text/event-stream", background=background + ) + else: + raise Exception(str(res)) + else: + res = await event_caller( + { + "type": "request:chat:completion", + "data": { + "form_data": form_data, + "model": models[form_data["model"]], + "channel": channel, + "session_id": session_id, + }, + } + ) + + if "error" in res: + raise Exception(res["error"]) + + return res + + async def generate_chat_completion( request: Request, form_data: dict, user: Any, bypass_filter: bool = False, ): + log.debug(f"generate_chat_completion: {form_data}") if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True - models = request.app.state.MODELS + if hasattr(request.state, "metadata"): + if "metadata" not in form_data: + form_data["metadata"] = request.state.metadata + else: + form_data["metadata"] = { + **form_data["metadata"], + **request.state.metadata, + } + + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + log.debug(f"direct connection to model: {models}") + else: + models = request.app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -80,85 +193,96 @@ async def generate_chat_completion( model = models[model_id] - # Check if user has access to the model - if not bypass_filter and user.role == "user": - try: - check_model_access(user, model) - except Exception as e: - raise e - - if model["owned_by"] == "arena": - model_ids = 
model.get("info", {}).get("meta", {}).get("model_ids") - filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") - if model_ids and filter_mode == "exclude": - model_ids = [ - model["id"] - for model in list(request.app.state.MODELS.values()) - if model.get("owned_by") != "arena" and model["id"] not in model_ids - ] - - selected_model_id = None - if isinstance(model_ids, list) and model_ids: - selected_model_id = random.choice(model_ids) - else: - model_ids = [ - model["id"] - for model in list(request.app.state.MODELS.values()) - if model.get("owned_by") != "arena" - ] - selected_model_id = random.choice(model_ids) - - form_data["model"] = selected_model_id - - if form_data.get("stream") == True: - - async def stream_wrapper(stream): - yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" - async for chunk in stream: - yield chunk - - response = await generate_chat_completion( - request, form_data, user, bypass_filter=True - ) - return StreamingResponse( - stream_wrapper(response.body_iterator), - media_type="text/event-stream", - background=response.background, - ) - else: - return { - **( - await generate_chat_completion( - request, form_data, user, bypass_filter=True - ) - ), - "selected_model_id": selected_model_id, - } - - if model.get("pipe"): - # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter - return await generate_function_chat_completion( + if getattr(request.state, "direct", False): + return await generate_direct_chat_completion( request, form_data, user=user, models=models ) - if model["owned_by"] == "ollama": - # Using /ollama/api/chat endpoint - form_data = convert_payload_openai_to_ollama(form_data) - response = await generate_ollama_chat_completion( - request=request, form_data=form_data, user=user, bypass_filter=bypass_filter - ) - if form_data.get("stream"): - response.headers["content-type"] = "text/event-stream" - return StreamingResponse( - convert_streaming_response_ollama_to_openai(response), - headers=dict(response.headers), - background=response.background, - ) - else: - return convert_response_ollama_to_openai(response) else: - return await generate_openai_chat_completion( - request=request, form_data=form_data, user=user, bypass_filter=bypass_filter - ) + # Check if user has access to the model + if not bypass_filter and user.role == "user": + try: + check_model_access(user, model) + except Exception as e: + raise e + + if model["owned_by"] == "arena": + model_ids = model.get("info", {}).get("meta", {}).get("model_ids") + filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") + if model_ids and filter_mode == "exclude": + model_ids = [ + model["id"] + for model in list(request.app.state.MODELS.values()) + if model.get("owned_by") != "arena" and model["id"] not in model_ids + ] + + selected_model_id = None + if isinstance(model_ids, list) and model_ids: + selected_model_id = random.choice(model_ids) + else: + model_ids = [ + model["id"] + for model in list(request.app.state.MODELS.values()) + if model.get("owned_by") != "arena" + ] + selected_model_id = random.choice(model_ids) + + form_data["model"] = selected_model_id + + if form_data.get("stream") == True: + + async def stream_wrapper(stream): + yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" + async for chunk in stream: + yield chunk + + response = await generate_chat_completion( + request, form_data, user, bypass_filter=True + ) + return 
StreamingResponse( + stream_wrapper(response.body_iterator), + media_type="text/event-stream", + background=response.background, + ) + else: + return { + **( + await generate_chat_completion( + request, form_data, user, bypass_filter=True + ) + ), + "selected_model_id": selected_model_id, + } + + if model.get("pipe"): + # Below does not require bypass_filter because this is the only route that uses this function, and it is already bypassing the filter + return await generate_function_chat_completion( + request, form_data, user=user, models=models + ) + if model["owned_by"] == "ollama": + # Using /ollama/api/chat endpoint + form_data = convert_payload_openai_to_ollama(form_data) + response = await generate_ollama_chat_completion( + request=request, + form_data=form_data, + user=user, + bypass_filter=bypass_filter, + ) + if form_data.get("stream"): + response.headers["content-type"] = "text/event-stream" + return StreamingResponse( + convert_streaming_response_ollama_to_openai(response), + headers=dict(response.headers), + background=response.background, + ) + else: + return convert_response_ollama_to_openai(response) + else: + return await generate_openai_chat_completion( + request=request, + form_data=form_data, + user=user, + bypass_filter=bypass_filter, + ) chat_completion = generate_chat_completion @@ -167,7 +291,13 @@ chat_completion = generate_chat_completion async def chat_completed(request: Request, form_data: dict, user: Any): if not request.app.state.MODELS: await get_all_models(request) - models = request.app.state.MODELS + + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS data = form_data model_id = data["model"] @@ -227,7 +357,13 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A if not request.app.state.MODELS: await get_all_models(request) - models = request.app.state.MODELS + + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS data = form_data model_id = data["model"] diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index acd429118..d1aaacb13 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -616,7 +616,13 @@ async def process_chat_payload(request, form_data, metadata, user, model): # Initialize events to store additional event to be sent to the client # Initialize contexts and citation - models = request.app.state.MODELS + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS + task_model_id = get_task_model_id( form_data["model"], request.app.state.config.TASK_MODEL, @@ -766,17 +772,7 @@ async def process_chat_payload(request, form_data, metadata, user, model): if "document" in source: for doc_idx, doc_context in enumerate(source["document"]): - doc_metadata = source.get("metadata") - doc_source_id = None - - if doc_metadata: - doc_source_id = doc_metadata[doc_idx].get("source", source_id) - - if source_id: - context_string += f"{doc_source_id if doc_source_id is not None else source_id}{doc_context}\n" - else: - # If there is no source_id, then do not include the source_id tag - context_string += f"{doc_context}\n" + context_string += 
f"{doc_idx}{doc_context}\n" context_string = context_string.strip() prompt = get_last_user_message(form_data["messages"]) @@ -1149,6 +1145,46 @@ async def process_chat_response( return content.strip() + def convert_content_blocks_to_messages(content_blocks): + messages = [] + + temp_blocks = [] + for idx, block in enumerate(content_blocks): + if block["type"] == "tool_calls": + messages.append( + { + "role": "assistant", + "content": serialize_content_blocks(temp_blocks), + "tool_calls": block.get("content"), + } + ) + + results = block.get("results", []) + + for result in results: + messages.append( + { + "role": "tool", + "tool_call_id": result["tool_call_id"], + "content": result["content"], + } + ) + temp_blocks = [] + else: + temp_blocks.append(block) + + if temp_blocks: + content = serialize_content_blocks(temp_blocks) + if content: + messages.append( + { + "role": "assistant", + "content": content, + } + ) + + return messages + def tag_content_handler(content_type, tags, content, content_blocks): end_flag = False @@ -1540,7 +1576,6 @@ async def process_chat_response( results = [] for tool_call in response_tool_calls: - print("\n\n" + str(tool_call) + "\n\n") tool_call_id = tool_call.get("id", "") tool_name = tool_call.get("function", {}).get("name", "") @@ -1606,23 +1641,10 @@ async def process_chat_response( { "model": model_id, "stream": True, + "tools": form_data["tools"], "messages": [ *form_data["messages"], - { - "role": "assistant", - "content": serialize_content_blocks( - content_blocks, raw=True - ), - "tool_calls": response_tool_calls, - }, - *[ - { - "role": "tool", - "tool_call_id": result["tool_call_id"], - "content": result["content"], - } - for result in results - ], + *convert_content_blocks_to_messages(content_blocks), ], }, user, @@ -1671,6 +1693,9 @@ async def process_chat_response( "data": { "id": str(uuid4()), "code": code, + "session_id": metadata.get( + "session_id", None + ), }, } ) @@ -1699,10 +1724,12 @@ async def process_chat_response( "stdout": "Code interpreter engine not configured." 
} + log.debug(f"Code interpreter output: {output}") + if isinstance(output, dict): stdout = output.get("stdout", "") - if stdout: + if isinstance(stdout, str): stdoutLines = stdout.split("\n") for idx, line in enumerate(stdoutLines): if "data:image/png;base64" in line: @@ -1734,7 +1761,7 @@ async def process_chat_response( result = output.get("result", "") - if result: + if isinstance(result, str): resultLines = result.split("\n") for idx, line in enumerate(resultLines): if "data:image/png;base64" in line: @@ -1784,6 +1811,8 @@ async def process_chat_response( } ) + log.debug(f"content_blocks={content_blocks} serialized={serialize_content_blocks(content_blocks)}") + try: res = await generate_chat_completion( request, diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index 99e6d9c39..f79b62684 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -217,12 +217,20 @@ def openai_chat_chunk_message_template( def openai_chat_completion_message_template( - model: str, message: Optional[str] = None, usage: Optional[dict] = None + model: str, + message: Optional[str] = None, + tool_calls: Optional[list[dict]] = None, + usage: Optional[dict] = None, ) -> dict: template = openai_chat_message_template(model) template["object"] = "chat.completion" if message is not None: - template["choices"][0]["message"] = {"content": message, "role": "assistant"} + template["choices"][0]["message"] = { + "content": message, + "role": "assistant", + **({"tool_calls": tool_calls} if tool_calls else {}), + } + template["choices"][0]["finish_reason"] = "stop" if usage: diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py index 4917d3852..f9979b4a2 100644 --- a/backend/open_webui/utils/response.py +++ b/backend/open_webui/utils/response.py @@ -6,9 +6,32 @@ from open_webui.utils.misc import ( ) +def convert_ollama_tool_call_to_openai(tool_calls: list[dict]) -> list[dict]: + openai_tool_calls = [] + for tool_call in tool_calls: + openai_tool_call = { + "index": tool_call.get("index", 0), + "id": tool_call.get("id", f"call_{str(uuid4())}"), + "type": "function", + "function": { + "name": tool_call.get("function", {}).get("name", ""), + "arguments": json.dumps( + tool_call.get("function", {}).get("arguments", {}) + ), + }, + } + openai_tool_calls.append(openai_tool_call) + return openai_tool_calls + + def convert_response_ollama_to_openai(ollama_response: dict) -> dict: model = ollama_response.get("model", "ollama") message_content = ollama_response.get("message", {}).get("content", "") + tool_calls = ollama_response.get("message", {}).get("tool_calls", None) + openai_tool_calls = None + + if tool_calls: + openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls) data = ollama_response usage = { @@ -51,7 +74,9 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict: ), } - response = openai_chat_completion_message_template(model, message_content, usage) + response = openai_chat_completion_message_template( + model, message_content, openai_tool_calls, usage + ) return response @@ -65,20 +90,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response) openai_tool_calls = None if tool_calls: - openai_tool_calls = [] - for tool_call in tool_calls: - openai_tool_call = { - "index": tool_call.get("index", 0), - "id": tool_call.get("id", f"call_{str(uuid4())}"), - "type": "function", - "function": { - "name": tool_call.get("function", {}).get("name", ""), - "arguments": json.dumps( - tool_call.get("function", {}).get("arguments", 
{}) - ), - }, - } - openai_tool_calls.append(openai_tool_call) + openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls) done = data.get("done", False) diff --git a/backend/requirements.txt b/backend/requirements.txt index 86755f50f..c9b7a4b78 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -3,9 +3,6 @@ uvicorn[standard]==0.30.6 pydantic==2.9.2 python-multipart==0.0.18 -Flask==3.1.0 -Flask-Cors==5.0.0 - python-socketio==5.11.3 python-jose==3.3.0 passlib[bcrypt]==1.7.4 diff --git a/package-lock.json b/package-lock.json index e5c18101b..56a76c09c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.5.10", + "version": "0.5.12", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.5.10", + "version": "0.5.12", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", diff --git a/package.json b/package.json index aa43f6a75..c1a76fd78 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.5.10", + "version": "0.5.12", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", diff --git a/pyproject.toml b/pyproject.toml index 544105e11..74b0facac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,9 +11,6 @@ dependencies = [ "pydantic==2.9.2", "python-multipart==0.0.18", - "Flask==3.1.0", - "Flask-Cors==5.0.0", - "python-socketio==5.11.3", "python-jose==3.3.0", "passlib[bcrypt]==1.7.4", @@ -55,7 +52,7 @@ dependencies = [ "chromadb==0.6.2", "pymilvus==2.5.0", "qdrant-client~=1.12.0", - "opensearch-py==2.7.1", + "opensearch-py==2.8.0", "playwright==1.49.1", "transformers", diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 53c577a45..3fb4a5d01 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -68,7 +68,21 @@ export const getModels = async ( })() ); } else { - requests.push(getOpenAIModelsDirect(url, OPENAI_API_KEYS[idx])); + requests.push( + (async () => { + return await getOpenAIModelsDirect(url, OPENAI_API_KEYS[idx]) + .then((res) => { + return res; + }) + .catch((err) => { + return { + object: 'list', + data: [], + urlIdx: idx + }; + }); + })() + ); } } else { requests.push( diff --git a/src/lib/components/admin/Settings/Interface.svelte b/src/lib/components/admin/Settings/Interface.svelte index 332d02c5a..316aaad1f 100644 --- a/src/lib/components/admin/Settings/Interface.svelte +++ b/src/lib/components/admin/Settings/Interface.svelte @@ -23,6 +23,7 @@ let taskConfig = { TASK_MODEL: '', TASK_MODEL_EXTERNAL: '', + ENABLE_TITLE_GENERATION: true, TITLE_GENERATION_PROMPT_TEMPLATE: '', IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: '', ENABLE_AUTOCOMPLETE_GENERATION: true, @@ -126,22 +127,34 @@ -
-
{$i18n.t('Title Generation Prompt')}
+
- -