diff --git a/CHANGELOG.md b/CHANGELOG.md index f6e8f7d29..3aaa79292 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,29 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.6.1] - 2025-04-05 + +### Added + +- 🛠️ **Global Tool Servers Configuration**: Admins can now centrally configure global external tool servers from Admin Settings > Tools, allowing seamless sharing of tool integrations across all users without manual setup per user. +- 🔐 **Direct Tool Usage Permission for Users**: Introduced a new user-level permission toggle that grants non-admin users access to direct external tools, empowering broader team collaboration while maintaining control. +- 🧠 **Mistral OCR Content Extraction Support**: Added native support for Mistral OCR as a high-accuracy document loader, drastically improving text extraction from scanned documents in RAG workflows. +- 🖼️ **Tools Indicator UI Redesign**: Enhanced message input now smartly displays both built-in and external tools via a unified dropdown, making it simpler and more intuitive to activate tools during conversations. +- 📄 **RAG Prompt Improved and More Coherent**: Default RAG system prompt has been revised to be more clear and citation-focused—admins can leave the template field empty to use this new gold-standard prompt. +- 🧰 **Performance & Developer Improvements**: Major internal restructuring of several tool-related components, simplifying styling and merging external/internal handling logic, resulting in better maintainability and performance. +- 🌍 **Improved Translations**: Updated translations for Tibetan, Polish, Chinese (Simplified & Traditional), Arabic, Russian, Ukrainian, Dutch, Finnish, and French to improve clarity and consistency across the interface. 
+ +### Fixed + +- 🔑 **External Tool Server API Key Bug Resolved**: Fixed a critical issue where authentication headers were not being sent when calling tools from external OpenAPI tool servers, ensuring full security and smooth tool operations. +- 🚫 **Conditional Export Button Visibility**: UI now gracefully hides export buttons when there's nothing to export in models, prompts, tools, or functions, improving visual clarity and reducing confusion. +- 🧪 **Hybrid Search Failure Recovery**: Resolved edge case in parallel hybrid search where empty or unindexed collections caused backend crashes—these are now cleanly skipped to ensure system stability. +- 📂 **Admin Folder Deletion Fix**: Addressed an issue where folders created in the admin workspace couldn't be deleted, restoring full organizational flexibility for admins. +- 🔐 **Improved Generic Error Feedback on Login**: Authentication errors now show simplified, non-revealing messages for privacy and improved UX, especially with federated logins. +- 📝 **Tool Message with Images Improved**: Enhanced how tool-generated messages with image outputs are shown in chat, making them more readable and consistent with the overall UI design. +- ⚙️ **Auto-Exclusion for Broken RAG Collections**: Auto-skips document collections that fail to fetch data or return "None", preventing silent errors and streamlining retrieval workflows. +- 📝 **Docling Text File Handling Fix**: Fixed file parsing inconsistency that broke docling-based RAG functionality for certain plain text files, ensuring wider file compatibility. 
+ ## [0.6.0] - 2025-03-31 ### Added diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 0a33c68b1..8238f8a87 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -881,6 +881,17 @@ except Exception: pass OPENAI_API_BASE_URL = "https://api.openai.com/v1" +#################################### +# TOOL_SERVERS +#################################### + + +TOOL_SERVER_CONNECTIONS = PersistentConfig( + "TOOL_SERVER_CONNECTIONS", + "tool_server.connections", + [], +) + #################################### # WEBUI #################################### @@ -1889,7 +1900,7 @@ CHUNK_OVERLAP = PersistentConfig( ) DEFAULT_RAG_TEMPLATE = """### Task: -Respond to the user query using the provided context, incorporating inline citations in the format [source_id] **only when the tag includes an explicit id attribute** (e.g., ). +Respond to the user query using the provided context, incorporating inline citations in the format [id] **only when the tag includes an explicit id attribute** (e.g., ). ### Guidelines: - If you don't know the answer, clearly state that. @@ -1897,17 +1908,17 @@ Respond to the user query using the provided context, incorporating inline citat - Respond in the same language as the user's query. - If the context is unreadable or of poor quality, inform the user and provide the best possible answer. - If the answer isn't present in the context but you possess the knowledge, explain this to the user and provide the answer using your own understanding. -- **Only include inline citations using [source_id] (e.g., [1], [2]) when the tag includes an id attribute.** -- Do not cite if the tag does not contain an id attribute. +- **Only include inline citations using [id] (e.g., [1], [2]) when the tag includes an id attribute.** +- Do not cite if the tag does not contain an id attribute. - Do not use XML tags in your response. - Ensure citations are concise and directly related to the information provided. 
### Example of Citation: -If the user asks about a specific topic and the information is found in a source with a provided id attribute, the response should include the citation like so: +If the user asks about a specific topic and the information is found in a source with a provided id attribute, the response should include the citation like in the following example: * "According to the study, the proposed method increases efficiency by 20% [1]." ### Output: -Provide a clear and direct response to the user's query, including inline citations in the format [source_id] only when the tag is present in the context. +Provide a clear and direct response to the user's query, including inline citations in the format [id] only when the tag with id attribute is present in the context. {{CONTEXT}} diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 383523174..c9ca059c2 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -105,6 +105,8 @@ from open_webui.config import ( OPENAI_API_CONFIGS, # Direct Connections ENABLE_DIRECT_CONNECTIONS, + # Tool Server Configs + TOOL_SERVER_CONNECTIONS, # Code Execution ENABLE_CODE_EXECUTION, CODE_EXECUTION_ENGINE, @@ -356,6 +358,7 @@ from open_webui.utils.access_control import has_access from open_webui.utils.auth import ( get_license_data, + get_http_authorization_cred, decode_token, get_admin_user, get_verified_user, @@ -478,6 +481,15 @@ app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS app.state.OPENAI_MODELS = {} +######################################## +# +# TOOL SERVERS +# +######################################## + +app.state.config.TOOL_SERVER_CONNECTIONS = TOOL_SERVER_CONNECTIONS +app.state.TOOL_SERVERS = [] + ######################################## # # DIRECT CONNECTIONS @@ -864,6 +876,10 @@ async def commit_session_after_request(request: Request, call_next): @app.middleware("http") async def check_url(request: Request, call_next): start_time = int(time.time()) + 
request.state.token = get_http_authorization_cred( + request.headers.get("Authorization") + ) + request.state.enable_api_key = app.state.config.ENABLE_API_KEY response = await call_next(request) process_time = int(time.time()) - start_time diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py index 7098822b4..24944bd8a 100644 --- a/backend/open_webui/retrieval/loaders/main.py +++ b/backend/open_webui/retrieval/loaders/main.py @@ -184,13 +184,16 @@ class Loader: for doc in docs ] + def _is_text_file(self, file_ext: str, file_content_type: str) -> bool: + return file_ext in known_source_ext or ( + file_content_type and file_content_type.find("text/") >= 0 + ) + def _get_loader(self, filename: str, file_content_type: str, file_path: str): file_ext = filename.split(".")[-1].lower() if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"): - if file_ext in known_source_ext or ( - file_content_type and file_content_type.find("text/") >= 0 - ): + if self._is_text_file(file_ext, file_content_type): loader = TextLoader(file_path, autodetect_encoding=True) else: loader = TikaLoader( @@ -199,11 +202,14 @@ class Loader: mime_type=file_content_type, ) elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"): - loader = DoclingLoader( - url=self.kwargs.get("DOCLING_SERVER_URL"), - file_path=file_path, - mime_type=file_content_type, - ) + if self._is_text_file(file_ext, file_content_type): + loader = TextLoader(file_path, autodetect_encoding=True) + else: + loader = DoclingLoader( + url=self.kwargs.get("DOCLING_SERVER_URL"), + file_path=file_path, + mime_type=file_content_type, + ) elif ( self.engine == "document_intelligence" and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != "" @@ -269,9 +275,7 @@ class Loader: loader = UnstructuredPowerPointLoader(file_path) elif file_ext == "msg": loader = OutlookMessageLoader(file_path) - elif file_ext in known_source_ext or ( - file_content_type and 
file_content_type.find("text/") >= 0 - ): + elif self._is_text_file(file_ext, file_content_type): loader = TextLoader(file_path, autodetect_encoding=True) else: loader = TextLoader(file_path, autodetect_encoding=True) diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 518a12136..f2d2c61de 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -320,10 +320,13 @@ def query_collection_with_hybrid_search( log.exception(f"Error when querying the collection with hybrid_search: {e}") return None, e + # Prepare tasks for all collections and queries + # Avoid running any tasks for collections that failed to fetch data (have assigned None) tasks = [ - (collection_name, query) - for collection_name in collection_names - for query in queries + (cn, q) + for cn in collection_names + if collection_results[cn] is not None + for q in queries ] with ThreadPoolExecutor() as executor: diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 07804f9ea..67c2e9f2a 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -194,8 +194,8 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): ciphers=LDAP_CIPHERS, ) except Exception as e: - log.error(f"An error occurred on TLS: {str(e)}") - raise HTTPException(400, detail=str(e)) + log.error(f"TLS configuration error: {str(e)}") + raise HTTPException(400, detail="Failed to configure TLS for LDAP connection.") try: server = Server( @@ -232,7 +232,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower() email = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"]) if not email or email == "" or email == "[]": - raise HTTPException(400, f"User {form_data.user} does not have email.") + raise HTTPException(400, "User does not have a valid email address.") else: email = email.lower() 
@@ -248,7 +248,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): authentication="SIMPLE", ) if not connection_user.bind(): - raise HTTPException(400, f"Authentication failed for {form_data.user}") + raise HTTPException(400, "Authentication failed.") user = Users.get_user_by_email(email) if not user: @@ -276,7 +276,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): except HTTPException: raise except Exception as err: - raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + log.error(f"LDAP user creation error: {str(err)}") + raise HTTPException( + 500, detail="Internal error occurred during LDAP user creation." + ) user = Auths.authenticate_user_by_trusted_header(email) @@ -312,12 +315,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) else: - raise HTTPException( - 400, - f"User {form_data.user} does not match the record. 
Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}", - ) + raise HTTPException(400, "User record mismatch.") except Exception as e: - raise HTTPException(400, detail=str(e)) + log.error(f"LDAP authentication error: {str(e)}") + raise HTTPException(400, detail="LDAP authentication failed.") ############################ @@ -519,7 +520,8 @@ async def signup(request: Request, response: Response, form_data: SignupForm): else: raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) except Exception as err: - raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + log.error(f"Signup error: {str(err)}") + raise HTTPException(500, detail="An internal error occurred during signup.") @router.get("/signout") @@ -547,7 +549,11 @@ async def signout(request: Request, response: Response): detail="Failed to fetch OpenID configuration", ) except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + log.error(f"OpenID signout error: {str(e)}") + raise HTTPException( + status_code=500, + detail="Failed to sign out from the OpenID provider.", + ) return {"status": True} @@ -591,7 +597,10 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): else: raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) except Exception as err: - raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + log.error(f"Add user error: {str(err)}") + raise HTTPException( + 500, detail="An internal error occurred while adding the user." 
+ ) ############################ diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 2a4c651f2..44b2ef40c 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -1,5 +1,5 @@ -from fastapi import APIRouter, Depends, Request -from pydantic import BaseModel +from fastapi import APIRouter, Depends, Request, HTTPException +from pydantic import BaseModel, ConfigDict from typing import Optional @@ -7,6 +7,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.config import get_config, save_config from open_webui.config import BannerModel +from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data + router = APIRouter() @@ -66,6 +68,75 @@ async def set_direct_connections_config( } +############################ +# ToolServers Config +############################ + + +class ToolServerConnection(BaseModel): + url: str + path: str + auth_type: Optional[str] + key: Optional[str] + config: Optional[dict] + + model_config = ConfigDict(extra="allow") + + +class ToolServersConfigForm(BaseModel): + TOOL_SERVER_CONNECTIONS: list[ToolServerConnection] + + +@router.get("/tool_servers", response_model=ToolServersConfigForm) +async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)): + return { + "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS, + } + + +@router.post("/tool_servers", response_model=ToolServersConfigForm) +async def set_tool_servers_config( + request: Request, + form_data: ToolServersConfigForm, + user=Depends(get_admin_user), +): + request.app.state.config.TOOL_SERVER_CONNECTIONS = [ + connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS + ] + + request.app.state.TOOL_SERVERS = await get_tool_servers_data( + request.app.state.config.TOOL_SERVER_CONNECTIONS + ) + + return { + "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS, + } + + 
+@router.post("/tool_servers/verify") +async def verify_tool_servers_config( + request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user) +): + """ + Verify the connection to the tool server. + """ + try: + + token = None + if form_data.auth_type == "bearer": + token = form_data.key + elif form_data.auth_type == "session": + token = request.state.token.credentials + + url = f"{form_data.url}/{form_data.path}" + return await get_tool_server_data(token, url) + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Failed to connect to the tool server: {str(e)}", + ) + + ############################ # CodeInterpreterConfig ############################ diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index fcb263d1e..775cd0446 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -1197,7 +1197,7 @@ class OpenAIChatMessageContent(BaseModel): class OpenAIChatMessage(BaseModel): role: str - content: Union[str, list[OpenAIChatMessageContent]] + content: Union[Optional[str], list[OpenAIChatMessageContent]] model_config = ConfigDict(extra="allow") diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 250d27eb3..6f71e11d3 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -1534,8 +1534,13 @@ def query_doc_handler( ): try: if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: + collection_results = {} + collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get( + collection_name=form_data.collection_name + ) return query_doc_with_hybrid_search( collection_name=form_data.collection_name, + collection_result=collection_results[form_data.collection_name], query=form_data.query, embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( query, prefix=prefix, user=user diff --git a/backend/open_webui/routers/tools.py 
b/backend/open_webui/routers/tools.py index 211264cde..8a98b4e20 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -1,6 +1,7 @@ import logging from pathlib import Path from typing import Optional +import time from open_webui.models.tools import ( ToolForm, @@ -18,6 +19,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission from open_webui.env import SRC_LOG_LEVELS +from open_webui.utils.tools import get_tool_servers_data + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) @@ -30,11 +33,51 @@ router = APIRouter() @router.get("/", response_model=list[ToolUserResponse]) -async def get_tools(user=Depends(get_verified_user)): - if user.role == "admin": - tools = Tools.get_tools() - else: - tools = Tools.get_tools_by_user_id(user.id, "read") +async def get_tools(request: Request, user=Depends(get_verified_user)): + + if not request.app.state.TOOL_SERVERS: + # If the tool servers are not set, we need to set them + # This is done only once when the server starts + # This is done to avoid loading the tool servers every time + + request.app.state.TOOL_SERVERS = await get_tool_servers_data( + request.app.state.config.TOOL_SERVER_CONNECTIONS + ) + + tools = Tools.get_tools() + for idx, server in enumerate(request.app.state.TOOL_SERVERS): + tools.append( + ToolUserResponse( + **{ + "id": f"server:{server['idx']}", + "user_id": f"server:{server['idx']}", + "name": server["openapi"] + .get("info", {}) + .get("title", "Tool Server"), + "meta": { + "description": server["openapi"] + .get("info", {}) + .get("description", ""), + }, + "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[ + idx + ] + .get("config", {}) + .get("access_control", None), + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + ) + + if user.role != "admin": + tools = [ + tool + for tool in tools + if tool.user_id == 
user.id + or has_access(user.id, "read", tool.access_control) + ] + return tools diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index 54ad6a0bf..118ac049e 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -143,12 +143,14 @@ def create_api_key(): return f"sk-{key}" -def get_http_authorization_cred(auth_header: str): +def get_http_authorization_cred(auth_header: Optional[str]): + if not auth_header: + return None try: scheme, credentials = auth_header.split(" ") return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) except Exception: - raise ValueError(ERROR_MESSAGES.INVALID_TOKEN) + return None def get_current_user( @@ -182,7 +184,12 @@ def get_current_user( ).split(",") ] - if request.url.path not in allowed_paths: + # Check if the request path matches any allowed endpoint. + if not any( + request.url.path == allowed + or request.url.path.startswith(allowed + "/") + for allowed in allowed_paths + ): raise HTTPException( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED ) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 72f1f30ce..62f43a702 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -221,13 +221,23 @@ async def chat_completion_tools_handler( except Exception as e: tool_result = str(e) + tool_result_files = [] + if isinstance(tool_result, list): + for item in tool_result: + # check if string + if isinstance(item, str) and item.startswith("data:"): + tool_result_files.append(item) + tool_result.remove(item) + if isinstance(tool_result, dict) or isinstance(tool_result, list): tool_result = json.dumps(tool_result, indent=2) if isinstance(tool_result, str): tool = tools[tool_function_name] - tool_id = tool.get("toolkit_id", "") - if tool.get("citation", False) or tool.get("direct", False): + tool_id = tool.get("tool_id", "") + if tool.get("metadata", 
{}).get("citation", False) or tool.get( + "direct", False + ): sources.append( { @@ -238,7 +248,7 @@ async def chat_completion_tools_handler( else f"{tool_function_name}" ), }, - "document": [tool_result], + "document": [tool_result, *tool_result_files], "metadata": [ { "source": ( @@ -254,7 +264,7 @@ async def chat_completion_tools_handler( sources.append( { "source": {}, - "document": [tool_result], + "document": [tool_result, *tool_result_files], "metadata": [ { "source": ( @@ -267,7 +277,11 @@ async def chat_completion_tools_handler( } ) - if tools[tool_function_name].get("file_handler", False): + if ( + tools[tool_function_name] + .get("metadata", {}) + .get("file_handler", False) + ): skip_files = True # check if "tool_calls" in result @@ -1906,7 +1920,8 @@ async def process_chat_response( tool_result_files = [] if isinstance(tool_result, list): for item in tool_result: - if item.startswith("data:"): + # check if string + if isinstance(item, str) and item.startswith("data:"): tool_result_files.append(item) tool_result.remove(item) diff --git a/backend/open_webui/utils/plugin.py b/backend/open_webui/utils/plugin.py index 29a4d0cce..f0746da77 100644 --- a/backend/open_webui/utils/plugin.py +++ b/backend/open_webui/utils/plugin.py @@ -68,23 +68,23 @@ def replace_imports(content): return content -def load_tools_module_by_id(toolkit_id, content=None): +def load_tools_module_by_id(tool_id, content=None): if content is None: - tool = Tools.get_tool_by_id(toolkit_id) + tool = Tools.get_tool_by_id(tool_id) if not tool: - raise Exception(f"Toolkit not found: {toolkit_id}") + raise Exception(f"Toolkit not found: {tool_id}") content = tool.content content = replace_imports(content) - Tools.update_tool_by_id(toolkit_id, {"content": content}) + Tools.update_tool_by_id(tool_id, {"content": content}) else: frontmatter = extract_frontmatter(content) # Install required packages found within the frontmatter install_frontmatter_requirements(frontmatter.get("requirements", "")) - 
module_name = f"tool_{toolkit_id}" + module_name = f"tool_{tool_id}" module = types.ModuleType(module_name) sys.modules[module_name] = module @@ -108,7 +108,7 @@ def load_tools_module_by_id(toolkit_id, content=None): else: raise Exception("No Tools class found in the module") except Exception as e: - log.error(f"Error loading module: {toolkit_id}: {e}") + log.error(f"Error loading module: {tool_id}: {e}") del sys.modules[module_name] # Clean up raise e finally: diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index bd2a731e6..60311a690 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -2,9 +2,10 @@ import inspect import logging import re import inspect -import uuid +import aiohttp +import asyncio -from typing import Any, Awaitable, Callable, get_type_hints +from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union, Optional from functools import update_wrapper, partial @@ -17,96 +18,162 @@ from open_webui.models.tools import Tools from open_webui.models.users import UserModel from open_webui.utils.plugin import load_tools_module_by_id +import copy + log = logging.getLogger(__name__) -def apply_extra_params_to_tool_function( +def get_async_tool_function_and_apply_extra_params( function: Callable, extra_params: dict ) -> Callable[..., Awaitable]: sig = inspect.signature(function) extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters} partial_func = partial(function, **extra_params) + if inspect.iscoroutinefunction(function): update_wrapper(partial_func, function) return partial_func + else: + # Make it a coroutine function + async def new_function(*args, **kwargs): + return partial_func(*args, **kwargs) - async def new_function(*args, **kwargs): - return partial_func(*args, **kwargs) - - update_wrapper(new_function, function) - return new_function + update_wrapper(new_function, function) + return new_function -# Mutation on extra_params def get_tools( 
request: Request, tool_ids: list[str], user: UserModel, extra_params: dict ) -> dict[str, dict]: tools_dict = {} for tool_id in tool_ids: - tools = Tools.get_tool_by_id(tool_id) - if tools is None: - continue + tool = Tools.get_tool_by_id(tool_id) + if tool is None: + if tool_id.startswith("server:"): + server_idx = int(tool_id.split(":")[1]) + tool_server_connection = ( + request.app.state.config.TOOL_SERVER_CONNECTIONS[server_idx] + ) + tool_server_data = request.app.state.TOOL_SERVERS[server_idx] + specs = tool_server_data.get("specs", []) - module = request.app.state.TOOLS.get(tool_id, None) - if module is None: - module, _ = load_tools_module_by_id(tool_id) - request.app.state.TOOLS[tool_id] = module + for spec in specs: + function_name = spec["name"] - extra_params["__id__"] = tool_id - if hasattr(module, "valves") and hasattr(module, "Valves"): - valves = Tools.get_tool_valves_by_id(tool_id) or {} - module.valves = module.Valves(**valves) + auth_type = tool_server_connection.get("auth_type", "bearer") + token = None - if hasattr(module, "UserValves"): - extra_params["__user__"]["valves"] = module.UserValves( # type: ignore - **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) - ) + if auth_type == "bearer": + token = tool_server_connection.get("key", "") + elif auth_type == "session": + token = request.state.token.credentials - for spec in tools.specs: - # TODO: Fix hack for OpenAI API - # Some times breaks OpenAI but others don't. 
Leaving the comment - for val in spec.get("parameters", {}).get("properties", {}).values(): - if val["type"] == "str": - val["type"] = "string" + def make_tool_function(function_name, token, tool_server_data): + async def tool_function(**kwargs): + print( + f"Executing tool function {function_name} with params: {kwargs}" + ) + return await execute_tool_server( + token=token, + url=tool_server_data["url"], + name=function_name, + params=kwargs, + server_data=tool_server_data, + ) - # Remove internal parameters - spec["parameters"]["properties"] = { - key: val - for key, val in spec["parameters"]["properties"].items() - if not key.startswith("__") - } + return tool_function - function_name = spec["name"] + tool_function = make_tool_function( + function_name, token, tool_server_data + ) - # convert to function that takes only model params and inserts custom params - original_func = getattr(module, function_name) - callable = apply_extra_params_to_tool_function(original_func, extra_params) + callable = get_async_tool_function_and_apply_extra_params( + tool_function, + {}, + ) - if callable.__doc__ and callable.__doc__.strip() != "": - s = re.split(":(param|return)", callable.__doc__, 1) - spec["description"] = s[0] + tool_dict = { + "tool_id": tool_id, + "callable": callable, + "spec": spec, + } + + # TODO: if collision, prepend toolkit name + if function_name in tools_dict: + log.warning( + f"Tool {function_name} already exists in another tools!" 
+ ) + log.warning(f"Discarding {tool_id}.{function_name}") + else: + tools_dict[function_name] = tool_dict else: - spec["description"] = function_name + continue + else: + module = request.app.state.TOOLS.get(tool_id, None) + if module is None: + module, _ = load_tools_module_by_id(tool_id) + request.app.state.TOOLS[tool_id] = module - # TODO: This needs to be a pydantic model - tool_dict = { - "spec": spec, - "callable": callable, - "toolkit_id": tool_id, - "pydantic_model": function_to_pydantic_model(callable), - # Misc info - "file_handler": hasattr(module, "file_handler") and module.file_handler, - "citation": hasattr(module, "citation") and module.citation, - } + extra_params["__id__"] = tool_id - # TODO: if collision, prepend toolkit name - if function_name in tools_dict: - log.warning(f"Tool {function_name} already exists in another tools!") - log.warning(f"Collision between {tools} and {tool_id}.") - log.warning(f"Discarding {tools}.{function_name}") - else: - tools_dict[function_name] = tool_dict + # Set valves for the tool + if hasattr(module, "valves") and hasattr(module, "Valves"): + valves = Tools.get_tool_valves_by_id(tool_id) or {} + module.valves = module.Valves(**valves) + if hasattr(module, "UserValves"): + extra_params["__user__"]["valves"] = module.UserValves( # type: ignore + **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + ) + + for spec in tool.specs: + # TODO: Fix hack for OpenAI API + # Some times breaks OpenAI but others don't. Leaving the comment + for val in spec.get("parameters", {}).get("properties", {}).values(): + if val["type"] == "str": + val["type"] = "string" + + # Remove internal reserved parameters (e.g. 
def resolve_schema(schema, components):
    """
    Recursively resolve ``$ref`` pointers in a JSON schema using OpenAPI components.

    Args:
        schema: A JSON-schema fragment. Falsy values (``None``, ``{}``) yield ``{}``.
        components: The mapping that ``$ref`` paths are looked up in. The first
            segment of every ``$ref`` path is skipped, and the remaining
            segments are used as successive keys into this mapping.

    Returns:
        A deep copy of ``schema`` in which a top-level ``$ref`` and every
        ``$ref`` found under ``properties`` or ``items`` is replaced by the
        schema it points to. A reference that cannot be found, or that points
        back to a schema which is currently being expanded, becomes ``{}``.
    """
    return _resolve_schema(schema, components, frozenset())


def _resolve_schema(schema, components, active_refs):
    """Resolve ``schema``; ``active_refs`` holds the ``$ref`` strings being expanded."""
    if not schema:
        return {}

    if "$ref" in schema:
        ref_path = schema["$ref"]
        if ref_path in active_refs:
            # Self-referential schema (e.g. a tree node that contains itself):
            # stop here instead of recursing until RecursionError. ``{}`` is
            # the same fallback used for references that cannot be found.
            return {}

        target = components
        for part in ref_path.strip("#/").split("/")[1:]:  # Skip the initial 'components'
            # A non-dict intermediate value means the pointer is unresolvable.
            target = target.get(part, {}) if isinstance(target, dict) else {}
        return _resolve_schema(target, components, active_refs | {ref_path})

    # Work on a copy so the caller's schema (and the components) stay untouched.
    resolved_schema = copy.deepcopy(schema)

    # Recursively resolve inner schemas
    if "properties" in resolved_schema:
        resolved_schema["properties"] = {
            prop: _resolve_schema(prop_schema, components, active_refs)
            for prop, prop_schema in resolved_schema["properties"].items()
        }

    if "items" in resolved_schema:
        resolved_schema["items"] = _resolve_schema(
            resolved_schema["items"], components, active_refs
        )

    return resolved_schema
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
    """
    Fetch an OpenAPI document from a tool server and derive its tool specs.

    Args:
        token: Bearer token sent in the ``Authorization`` header; the header
            is omitted when the token is falsy.
        url: URL of the OpenAPI JSON document.

    Returns:
        A dict with the parsed document (``openapi``), its ``info`` section
        (``{}`` when absent) and the tool payloads (``specs``) produced by
        ``convert_openapi_to_tool_payload``.

    Raises:
        Exception: If the request fails, the server answers with a non-200
            status, or the body cannot be parsed as JSON. The message is the
            ``detail`` field of a JSON error body when present, otherwise the
            string form of the underlying error.
    """
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    if token:
        headers["Authorization"] = f"Bearer {token}"

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers) as response:
                if response.status != 200:
                    try:
                        error_body = await response.json()
                    except Exception:
                        # Error page is not JSON: keep the status and raw text
                        # rather than reporting a JSON decoding problem.
                        error_body = (
                            f"HTTP error {response.status}: {await response.text()}"
                        )
                    raise Exception(error_body)
                res = await response.json()
    except Exception as err:
        log.error("Failed to fetch tool server data from %s: %s", url, err)
        # The error body travels as the exception's first argument; the
        # exception object itself is never a dict.
        payload = err.args[0] if err.args else None
        if isinstance(payload, dict) and "detail" in payload:
            error = payload["detail"]
        else:
            error = str(err)
        raise Exception(error) from err

    data = {
        "openapi": res,
        "info": res.get("info", {}),
        "specs": convert_openapi_to_tool_payload(res),
    }

    # Log a summary only; the full document can be large.
    log.debug("Fetched %d tool spec(s) from %s", len(data["specs"]), url)
    return data
async def execute_tool_server(
    token: str, url: str, name: str, params: Dict[str, Any], server_data: Dict[str, Any]
) -> Any:
    """
    Call the tool-server operation whose ``operationId`` equals ``name``.

    Args:
        token: Bearer token sent in the ``Authorization`` header; the header
            is omitted when the token is falsy.
        url: Base URL of the tool server; the route path is appended to it.
        name: ``operationId`` of the operation to invoke.
        params: Arguments for the call. Entries named by the operation's
            ``path`` / ``query`` parameters are placed in the URL. When the
            operation declares a request body, the whole mapping is sent as
            the JSON body (for ``post``, ``put`` and ``patch``).
        server_data: Dict whose ``openapi`` entry holds the server's OpenAPI
            document.

    Returns:
        The parsed JSON response on success. On any failure (unknown
        operation, missing body, HTTP status >= 400, connection or decoding
        error) a dict ``{"error": message}`` is returned instead of raising.
    """
    # Function-scope import so the block does not depend on file-level imports.
    from urllib.parse import quote, urlencode

    try:
        paths = server_data.get("openapi", {}).get("paths", {})

        match = _find_tool_server_operation(paths, name)
        if match is None:
            raise Exception(f"No matching route found for operationId: {name}")
        route_path, http_method, operation = match

        path_params = {}
        query_params = {}
        for param in operation.get("parameters", []):
            param_name = param["name"]
            if param_name not in params:
                continue
            param_in = param["in"]
            if param_in == "path":
                path_params[param_name] = params[param_name]
            elif param_in == "query":
                query_params[param_name] = params[param_name]

        final_url = f"{url}{route_path}"
        for key, value in path_params.items():
            # Percent-encode so values containing "/", "?", "#" or spaces
            # cannot change the shape of the URL.
            final_url = final_url.replace(f"{{{key}}}", quote(str(value), safe=""))

        if query_params:
            # urlencode escapes "&", "=" and spaces; doseq expands list values
            # into repeated keys.
            final_url = f"{final_url}?{urlencode(query_params, doseq=True)}"

        body_params = {}
        if operation.get("requestBody", {}).get("content"):
            if not params:
                raise Exception(
                    f"Request body expected for operation '{name}' but none found."
                )
            body_params = params

        headers = {"Content-Type": "application/json"}
        if token:
            headers["Authorization"] = f"Bearer {token}"

        request_kwargs = {"headers": headers}
        if http_method in ("post", "put", "patch"):
            request_kwargs["json"] = body_params

        async with aiohttp.ClientSession() as session:
            async with session.request(
                http_method.upper(), final_url, **request_kwargs
            ) as response:
                if response.status >= 400:
                    text = await response.text()
                    raise Exception(f"HTTP error {response.status}: {text}")
                return await response.json()

    except Exception as err:
        error = str(err)
        log.error("API Request Error: %s", error)
        return {"error": error}


def _find_tool_server_operation(paths, name):
    """Return ``(route_path, http_method, operation)`` for ``operationId == name``, or ``None``."""
    for route_path, methods in paths.items():
        if not isinstance(methods, dict):
            continue
        for http_method, operation in methods.items():
            # A path item can also carry non-operation entries (for example a
            # shared "parameters" list); only dict entries can be operations.
            if isinstance(operation, dict) and operation.get("operationId") == name:
                return route_path, http_method.lower(), operation
    return None
b/pyproject.toml index 2e8537a77..4c420af79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ dependencies = [ "openai", "anthropic", - "google-generativeai==0.7.2", + "google-generativeai==0.8.4", "tiktoken", "langchain==0.3.19", @@ -62,7 +62,7 @@ dependencies = [ "transformers", "sentence-transformers==3.3.1", "colbert-ai==0.2.21", - "einops==0.8.0", + "einops==0.8.1", "ftfy==6.2.3", "pypdf==4.3.1", @@ -73,7 +73,7 @@ dependencies = [ "unstructured==0.16.17", "nltk==3.9.1", "Markdown==3.7", - "pypandoc==1.13", + "pypandoc==1.15", "pandas==2.2.3", "openpyxl==3.1.5", "pyxlsb==1.0.10", @@ -89,6 +89,8 @@ dependencies = [ "rapidocr-onnxruntime==1.3.24", "rank-bm25==0.2.2", + "onnxruntime==1.20.1", + "faster-whisper==1.1.1", "PyJWT[crypto]==2.10.1", diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts index f7f02c740..5872303f6 100644 --- a/src/lib/apis/configs/index.ts +++ b/src/lib/apis/configs/index.ts @@ -115,6 +115,93 @@ export const setDirectConnectionsConfig = async (token: string, config: object) return res; }; +export const getToolServerConnections = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/tool_servers`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const setToolServerConnections = async (token: string, connections: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/tool_servers`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...connections + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return 
res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const verifyToolServerConnection = async (token: string, connection: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/tool_servers/verify`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...connection + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getCodeExecutionConfig = async (token: string) => { let error = null; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 015e1272a..cdd6887b2 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -262,7 +262,7 @@ export const stopTask = async (token: string, id: string) => { export const getToolServerData = async (token: string, url: string) => { let error = null; - const res = await fetch(`${url}/openapi.json`, { + const res = await fetch(`${url}`, { method: 'GET', headers: { Accept: 'application/json', @@ -304,10 +304,13 @@ export const getToolServersData = async (i18n, servers: object[]) => { servers .filter((server) => server?.config?.enable) .map(async (server) => { - const data = await getToolServerData(server?.key, server?.url).catch((err) => { + const data = await getToolServerData( + server?.key, + server?.url + '/' + (server?.path ?? 'openapi.json') + ).catch((err) => { toast.error( i18n.t(`Failed to connect to {{URL}} OpenAPI tool server`, { - URL: server?.url + URL: server?.url + '/' + (server?.path ?? 
'openapi.json') }) ); return null; diff --git a/src/lib/components/AddServerModal.svelte b/src/lib/components/AddServerModal.svelte index d0f79c576..1ce7369e4 100644 --- a/src/lib/components/AddServerModal.svelte +++ b/src/lib/components/AddServerModal.svelte @@ -15,6 +15,9 @@ import Tooltip from '$lib/components/common/Tooltip.svelte'; import Switch from '$lib/components/common/Switch.svelte'; import Tags from './common/Tags.svelte'; + import { getToolServerData } from '$lib/apis'; + import { verifyToolServerConnection } from '$lib/apis/configs'; + import AccessControl from './workspace/common/AccessControl.svelte'; export let onSubmit: Function = () => {}; export let onDelete: Function = () => {}; @@ -22,18 +25,66 @@ export let show = false; export let edit = false; + export let direct = false; + export let connection = null; let url = ''; - let path = '/openapi.json'; + let path = 'openapi.json'; let auth_type = 'bearer'; let key = ''; + let accessControl = null; + let enable = true; let loading = false; + const verifyHandler = async () => { + if (url === '') { + toast.error($i18n.t('Please enter a valid URL')); + return; + } + + if (path === '') { + toast.error($i18n.t('Please enter a valid path')); + return; + } + + if (direct) { + const res = await getToolServerData( + auth_type === 'bearer' ? 
key : localStorage.token, + `${url}/${path}` + ).catch((err) => { + toast.error($i18n.t('Connection failed')); + }); + + if (res) { + toast.success($i18n.t('Connection successful')); + console.debug('Connection successful', res); + } + } else { + const res = await verifyToolServerConnection(localStorage.token, { + url, + path, + auth_type, + key, + config: { + enable: enable, + access_control: accessControl + } + }).catch((err) => { + toast.error($i18n.t('Connection failed')); + }); + + if (res) { + toast.success($i18n.t('Connection successful')); + console.debug('Connection successful', res); + } + } + }; + const submitHandler = async () => { loading = true; @@ -46,7 +97,8 @@ auth_type, key, config: { - enable: enable + enable: enable, + access_control: accessControl } }; @@ -56,22 +108,24 @@ show = false; url = ''; + path = 'openapi.json'; key = ''; - path = '/openapi.json'; auth_type = 'bearer'; enable = true; + accessControl = null; }; const init = () => { if (connection) { url = connection.url; - path = connection?.path ?? '/openapi.json'; + path = connection?.path ?? 'openapi.json'; auth_type = connection?.auth_type ?? 'bearer'; key = connection?.key ?? ''; enable = connection.config?.enable ?? true; + accessControl = connection.config?.access_control ?? null; } }; @@ -125,20 +179,53 @@
-
{$i18n.t('URL')}
+
+
{$i18n.t('URL')}
+
-
+
+ + + + + + + +
-
+
+
/
- -
- - - -
- {$i18n.t(`WebUI will make requests to "{{url}}{{path}}"`, { - url: url, - path: path + {$i18n.t(`WebUI will make requests to "{{url}}"`, { + url: `${url}/${path}` })}
@@ -171,7 +251,7 @@