diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index bd4ffcbf1..fe59a4e9b 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -3,6 +3,7 @@ import logging import os import shutil import base64 +import redis from datetime import datetime from pathlib import Path @@ -17,6 +18,7 @@ from open_webui.env import ( DATA_DIR, DATABASE_URL, ENV, + REDIS_URL, FRONTEND_BUILD_DIR, OFFLINE_MODE, OPEN_WEBUI_DIR, @@ -248,9 +250,14 @@ class PersistentConfig(Generic[T]): class AppConfig: _state: dict[str, PersistentConfig] + _redis: Optional[redis.Redis] = None - def __init__(self): + def __init__(self, redis_url: Optional[str] = None): super().__setattr__("_state", {}) + if redis_url: + super().__setattr__( + "_redis", redis.Redis.from_url(redis_url, decode_responses=True) + ) def __setattr__(self, key, value): if isinstance(value, PersistentConfig): @@ -259,7 +266,31 @@ class AppConfig: self._state[key].value = value self._state[key].save() + if self._redis: + redis_key = f"open-webui:config:{key}" + self._redis.set(redis_key, json.dumps(self._state[key].value)) + def __getattr__(self, key): + if key not in self._state: + raise AttributeError(f"Config key '{key}' not found") + + # If Redis is available, check for an updated value + if self._redis: + redis_key = f"open-webui:config:{key}" + redis_value = self._redis.get(redis_key) + + if redis_value is not None: + try: + decoded_value = json.loads(redis_value) + + # Update the in-memory value if different + if self._state[key].value != decoded_value: + self._state[key].value = decoded_value + log.info(f"Updated {key} from Redis: {decoded_value}") + + except json.JSONDecodeError: + log.error(f"Invalid JSON format in Redis for {key}: {redis_value}") + return self._state[key].value @@ -1956,6 +1987,12 @@ TAVILY_API_KEY = PersistentConfig( os.getenv("TAVILY_API_KEY", ""), ) +TAVILY_EXTRACT_DEPTH = PersistentConfig( + "TAVILY_EXTRACT_DEPTH", + "rag.web.search.tavily_extract_depth", + os.getenv("TAVILY_EXTRACT_DEPTH", "basic"), +) + JINA_API_KEY = PersistentConfig( "JINA_API_KEY", "rag.web.search.jina_api_key", diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 3b3d6d4f3..2abf65924 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -330,7 +330,7 @@ ENABLE_REALTIME_CHAT_SAVE = ( # REDIS #################################### -REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0") +REDIS_URL = os.environ.get("REDIS_URL", "") #################################### # WEBUI_AUTH (Required for security) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 0ace155eb..a453df0d7 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -213,6 +213,7 @@ from open_webui.config import ( SERPSTACK_API_KEY, SERPSTACK_HTTPS, TAVILY_API_KEY, + TAVILY_EXTRACT_DEPTH, BING_SEARCH_V7_ENDPOINT, BING_SEARCH_V7_SUBSCRIPTION_KEY, BRAVE_SEARCH_API_KEY, @@ -313,6 +314,7 @@ from open_webui.env import ( AUDIT_EXCLUDED_PATHS, AUDIT_LOG_LEVEL, CHANGELOG, + REDIS_URL, GLOBAL_LOG_LEVEL, MAX_BODY_LOG_SIZE, SAFE_MODE, @@ -419,7 +421,7 @@ app = FastAPI( oauth_manager = OAuthManager(app) -app.state.config = AppConfig() +app.state.config = AppConfig(redis_url=REDIS_URL) app.state.WEBUI_NAME = WEBUI_NAME app.state.LICENSE_METADATA = None @@ -616,6 +618,7 @@ app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY +app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH app.state.EMBEDDING_FUNCTION = None app.state.ef = None diff --git a/backend/open_webui/retrieval/loaders/tavily.py b/backend/open_webui/retrieval/loaders/tavily.py new file mode 100644 index 000000000..b96396eba --- /dev/null +++ b/backend/open_webui/retrieval/loaders/tavily.py @@ -0,0 +1,98 @@ +import requests +import logging +from typing import Iterator, List, Literal, Union + +from langchain_core.document_loaders import BaseLoader +from langchain_core.documents import Document +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + +class TavilyLoader(BaseLoader): + """Extract web page content from URLs using Tavily Extract API. + + This is a LangChain document loader that uses Tavily's Extract API to + retrieve content from web pages and return it as Document objects. + + Args: + urls: URL or list of URLs to extract content from. + api_key: The Tavily API key. + extract_depth: Depth of extraction, either "basic" or "advanced". + continue_on_failure: Whether to continue if extraction of a URL fails. + """ + def __init__( + self, + urls: Union[str, List[str]], + api_key: str, + extract_depth: Literal["basic", "advanced"] = "basic", + continue_on_failure: bool = True, + ) -> None: + """Initialize Tavily Extract client. + + Args: + urls: URL or list of URLs to extract content from. + api_key: The Tavily API key. + include_images: Whether to include images in the extraction. + extract_depth: Depth of extraction, either "basic" or "advanced". + advanced extraction retrieves more data, including tables and + embedded content, with higher success but may increase latency. + basic costs 1 credit per 5 successful URL extractions, + advanced costs 2 credits per 5 successful URL extractions. + continue_on_failure: Whether to continue if extraction of a URL fails. + """ + if not urls: + raise ValueError("At least one URL must be provided.") + + self.api_key = api_key + self.urls = urls if isinstance(urls, list) else [urls] + self.extract_depth = extract_depth + self.continue_on_failure = continue_on_failure + self.api_url = "https://api.tavily.com/extract" + + def lazy_load(self) -> Iterator[Document]: + """Extract and yield documents from the URLs using Tavily Extract API.""" + batch_size = 20 + for i in range(0, len(self.urls), batch_size): + batch_urls = self.urls[i:i + batch_size] + try: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}" + } + # Use string for single URL, array for multiple URLs + urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls + payload = { + "urls": urls_param, + "extract_depth": self.extract_depth + } + # Make the API call + response = requests.post( + self.api_url, + headers=headers, + json=payload + ) + response.raise_for_status() + response_data = response.json() + # Process successful results + for result in response_data.get("results", []): + url = result.get("url", "") + content = result.get("raw_content", "") + if not content: + log.warning(f"No content extracted from {url}") + continue + # Add URLs as metadata + metadata = {"source": url} + yield Document( + page_content=content, + metadata=metadata, + ) + for failed in response_data.get("failed_results", []): + url = failed.get("url", "") + error = failed.get("error", "Unknown error") + log.error(f"Failed to extract content from {url}: {error}") + except Exception as e: + if self.continue_on_failure: + log.error(f"Error extracting content from batch {batch_urls}: {e}") + else: + raise e \ No newline at end of file diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py index 2629bfcba..4844f7d4e 100644 --- a/backend/open_webui/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -1,4 +1,5 @@ from opensearchpy import OpenSearch +from opensearchpy.helpers import bulk from typing import Optional from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult @@ -20,8 +21,14 @@ class OpenSearchClient: verify_certs=OPENSEARCH_CERT_VERIFY, http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD), ) + + def _get_index_name(self, collection_name: str) -> str: + return f"{self.index_prefix}_{collection_name}" def _result_to_get_result(self, result) -> GetResult: + if not result["hits"]["hits"]: + return None + ids = [] documents = [] metadatas = [] @@ -31,9 +38,12 @@ class OpenSearchClient: documents.append(hit["_source"].get("text")) metadatas.append(hit["_source"].get("metadata")) - return GetResult(ids=ids, documents=documents, metadatas=metadatas) + return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) def _result_to_search_result(self, result) -> SearchResult: + if not result["hits"]["hits"]: + return None + ids = [] distances = [] documents = [] @@ -46,25 +56,32 @@ class OpenSearchClient: metadatas.append(hit["_source"].get("metadata")) return SearchResult( - ids=ids, distances=distances, documents=documents, metadatas=metadatas + ids=[ids], distances=[distances], documents=[documents], metadatas=[metadatas] ) def _create_index(self, collection_name: str, dimension: int): body = { + "settings": { + "index": { + "knn": True + } + }, "mappings": { "properties": { "id": {"type": "keyword"}, "vector": { - "type": "dense_vector", - "dims": dimension, # Adjust based on your vector dimensions - "index": true, + "type": "knn_vector", + "dimension": dimension, # Adjust based on your vector dimensions + "index": True, "similarity": "faiss", "method": { "name": "hnsw", - "space_type": "ip", # Use inner product to approximate cosine similarity + "space_type": "innerproduct", # Use inner product to approximate cosine similarity "engine": "faiss", - "ef_construction": 128, - "m": 16, + "parameters": { + "ef_construction": 128, + "m": 16, + } }, }, "text": {"type": "text"}, @@ -73,7 +90,7 @@ class OpenSearchClient: } } self.client.indices.create( - index=f"{self.index_prefix}_{collection_name}", body=body + index=self._get_index_name(collection_name), body=body ) def _create_batches(self, items: list[VectorItem], batch_size=100): @@ -84,38 +101,49 @@ class OpenSearchClient: # has_collection here means has index. # We are simply adapting to the norms of the other DBs. return self.client.indices.exists( - index=f"{self.index_prefix}_{collection_name}" + index=self._get_index_name(collection_name) ) - def delete_colleciton(self, collection_name: str): + def delete_collection(self, collection_name: str): # delete_collection here means delete index. # We are simply adapting to the norms of the other DBs. - self.client.indices.delete(index=f"{self.index_prefix}_{collection_name}") + self.client.indices.delete(index=self._get_index_name(collection_name)) def search( - self, collection_name: str, vectors: list[list[float]], limit: int + self, collection_name: str, vectors: list[list[float | int]], limit: int ) -> Optional[SearchResult]: - query = { - "size": limit, - "_source": ["text", "metadata"], - "query": { - "script_score": { - "query": {"match_all": {}}, - "script": { - "source": "cosineSimilarity(params.vector, 'vector') + 1.0", - "params": { - "vector": vectors[0] - }, # Assuming single query vector - }, - } - }, - } + try: + if not self.has_collection(collection_name): + return None + + query = { + "size": limit, + "_source": ["text", "metadata"], + "query": { + "script_score": { + "query": { + "match_all": {} + }, + "script": { + "source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0", + "params": { + "field": "vector", + "query_value": vectors[0] + }, # Assuming single query vector + }, + } + }, + } + + result = self.client.search( + index=self._get_index_name(collection_name), + body=query + ) - result = self.client.search( - index=f"{self.index_prefix}_{collection_name}", body=query - ) - - return self._result_to_search_result(result) + return self._result_to_search_result(result) + + except Exception as e: + return None def query( self, collection_name: str, filter: dict, limit: Optional[int] = None @@ -124,18 +152,26 @@ class OpenSearchClient: return None query_body = { - "query": {"bool": {"filter": []}}, + "query": { + "bool": { + "filter": [] + } + }, "_source": ["text", "metadata"], } for field, value in filter.items(): - query_body["query"]["bool"]["filter"].append({"term": {field: value}}) + query_body["query"]["bool"]["filter"].append({ + "match": { + "metadata." + str(field): value + } + }) size = limit if limit else 10 try: result = self.client.search( - index=f"{self.index_prefix}_{collection_name}", + index=self._get_index_name(collection_name), body=query_body, size=size, ) @@ -146,14 +182,14 @@ class OpenSearchClient: return None def _create_index_if_not_exists(self, collection_name: str, dimension: int): - if not self.has_index(collection_name): + if not self.has_collection(collection_name): self._create_index(collection_name, dimension) def get(self, collection_name: str) -> Optional[GetResult]: query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]} result = self.client.search( - index=f"{self.index_prefix}_{collection_name}", body=query + index=self._get_index_name(collection_name), body=query ) return self._result_to_get_result(result) @@ -165,18 +201,18 @@ class OpenSearchClient: for batch in self._create_batches(items): actions = [ { - "index": { - "_id": item["id"], - "_source": { - "vector": item["vector"], - "text": item["text"], - "metadata": item["metadata"], - }, - } + "_op_type": "index", + "_index": self._get_index_name(collection_name), + "_id": item["id"], + "_source": { + "vector": item["vector"], + "text": item["text"], + "metadata": item["metadata"], + }, } for item in batch ] - self.client.bulk(actions) + bulk(self.client, actions) def upsert(self, collection_name: str, items: list[VectorItem]): self._create_index_if_not_exists( @@ -186,27 +222,47 @@ class OpenSearchClient: for batch in self._create_batches(items): actions = [ { - "index": { - "_id": item["id"], - "_index": f"{self.index_prefix}_{collection_name}", - "_source": { - "vector": item["vector"], - "text": item["text"], - "metadata": item["metadata"], - }, - } + "_op_type": "update", + "_index": self._get_index_name(collection_name), + "_id": item["id"], + "doc": { + "vector": item["vector"], + "text": item["text"], + "metadata": item["metadata"], + }, + "doc_as_upsert": True, } for item in batch ] - self.client.bulk(actions) - - def delete(self, collection_name: str, ids: list[str]): - actions = [ - {"delete": {"_index": f"{self.index_prefix}_{collection_name}", "_id": id}} - for id in ids - ] - self.client.bulk(body=actions) + bulk(self.client, actions) + def delete(self, collection_name: str, ids: Optional[list[str]] = None, filter: Optional[dict] = None): + if ids: + actions = [ + { + "_op_type": "delete", + "_index": self._get_index_name(collection_name), + "_id": id, + } + for id in ids + ] + bulk(self.client, actions) + elif filter: + query_body = { + "query": { + "bool": { + "filter": [] + } + }, + } + for field, value in filter.items(): + query_body["query"]["bool"]["filter"].append({ + "match": { + "metadata." + str(field): value + } + }) + self.client.delete_by_query(index=self._get_index_name(collection_name), body=query_body) + def reset(self): indices = self.client.indices.get(index=f"{self.index_prefix}_*") for index in indices: diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index fd94a1a32..65654d8e8 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -24,6 +24,7 @@ from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoa from langchain_community.document_loaders.firecrawl import FireCrawlLoader from langchain_community.document_loaders.base import BaseLoader from langchain_core.documents import Document +from open_webui.retrieval.loaders.tavily import TavilyLoader from open_webui.constants import ERROR_MESSAGES from open_webui.config import ( ENABLE_RAG_LOCAL_WEB_FETCH, @@ -31,6 +32,8 @@ from open_webui.config import ( RAG_WEB_LOADER_ENGINE, FIRECRAWL_API_BASE_URL, FIRECRAWL_API_KEY, + TAVILY_API_KEY, + TAVILY_EXTRACT_DEPTH, ) from open_webui.env import SRC_LOG_LEVELS @@ -113,7 +116,47 @@ def verify_ssl_cert(url: str) -> bool: return False -class SafeFireCrawlLoader(BaseLoader): +class RateLimitMixin: + async def _wait_for_rate_limit(self): + """Wait to respect the rate limit if specified.""" + if self.requests_per_second and self.last_request_time: + min_interval = timedelta(seconds=1.0 / self.requests_per_second) + time_since_last = datetime.now() - self.last_request_time + if time_since_last < min_interval: + await asyncio.sleep((min_interval - time_since_last).total_seconds()) + self.last_request_time = datetime.now() + + def _sync_wait_for_rate_limit(self): + """Synchronous version of rate limit wait.""" + if self.requests_per_second and self.last_request_time: + min_interval = timedelta(seconds=1.0 / self.requests_per_second) + time_since_last = datetime.now() - self.last_request_time + if time_since_last < min_interval: + time.sleep((min_interval - time_since_last).total_seconds()) + self.last_request_time = datetime.now() + + +class URLProcessingMixin: + def _verify_ssl_cert(self, url: str) -> bool: + """Verify SSL certificate for a URL.""" + return verify_ssl_cert(url) + + async def _safe_process_url(self, url: str) -> bool: + """Perform safety checks before processing a URL.""" + if self.verify_ssl and not self._verify_ssl_cert(url): + raise ValueError(f"SSL certificate verification failed for {url}") + await self._wait_for_rate_limit() + return True + + def _safe_process_url_sync(self, url: str) -> bool: + """Synchronous version of safety checks.""" + if self.verify_ssl and not self._verify_ssl_cert(url): + raise ValueError(f"SSL certificate verification failed for {url}") + self._sync_wait_for_rate_limit() + return True + + +class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): def __init__( self, web_paths, @@ -208,43 +251,120 @@ class SafeFireCrawlLoader(BaseLoader): continue raise e - def _verify_ssl_cert(self, url: str) -> bool: - return verify_ssl_cert(url) - async def _wait_for_rate_limit(self): - """Wait to respect the rate limit if specified.""" - if self.requests_per_second and self.last_request_time: - min_interval = timedelta(seconds=1.0 / self.requests_per_second) - time_since_last = datetime.now() - self.last_request_time - if time_since_last < min_interval: - await asyncio.sleep((min_interval - time_since_last).total_seconds()) - self.last_request_time = datetime.now() +class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): + def __init__( + self, + web_paths: Union[str, List[str]], + api_key: str, + extract_depth: Literal["basic", "advanced"] = "basic", + continue_on_failure: bool = True, + requests_per_second: Optional[float] = None, + verify_ssl: bool = True, + trust_env: bool = False, + proxy: Optional[Dict[str, str]] = None, + ): + """Initialize SafeTavilyLoader with rate limiting and SSL verification support. - def _sync_wait_for_rate_limit(self): - """Synchronous version of rate limit wait.""" - if self.requests_per_second and self.last_request_time: - min_interval = timedelta(seconds=1.0 / self.requests_per_second) - time_since_last = datetime.now() - self.last_request_time - if time_since_last < min_interval: - time.sleep((min_interval - time_since_last).total_seconds()) - self.last_request_time = datetime.now() + Args: + web_paths: List of URLs/paths to process. + api_key: The Tavily API key. + extract_depth: Depth of extraction ("basic" or "advanced"). + continue_on_failure: Whether to continue if extraction of a URL fails. + requests_per_second: Number of requests per second to limit to. + verify_ssl: If True, verify SSL certificates. + trust_env: If True, use proxy settings from environment variables. + proxy: Optional proxy configuration. + """ + # Initialize proxy configuration if using environment variables + proxy_server = proxy.get("server") if proxy else None + if trust_env and not proxy_server: + env_proxies = urllib.request.getproxies() + env_proxy_server = env_proxies.get("https") or env_proxies.get("http") + if env_proxy_server: + if proxy: + proxy["server"] = env_proxy_server + else: + proxy = {"server": env_proxy_server} + + # Store parameters for creating TavilyLoader instances + self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths] + self.api_key = api_key + self.extract_depth = extract_depth + self.continue_on_failure = continue_on_failure + self.verify_ssl = verify_ssl + self.trust_env = trust_env + self.proxy = proxy + + # Add rate limiting + self.requests_per_second = requests_per_second + self.last_request_time = None - async def _safe_process_url(self, url: str) -> bool: - """Perform safety checks before processing a URL.""" - if self.verify_ssl and not self._verify_ssl_cert(url): - raise ValueError(f"SSL certificate verification failed for {url}") - await self._wait_for_rate_limit() - return True - - def _safe_process_url_sync(self, url: str) -> bool: - """Synchronous version of safety checks.""" - if self.verify_ssl and not self._verify_ssl_cert(url): - raise ValueError(f"SSL certificate verification failed for {url}") - self._sync_wait_for_rate_limit() - return True + def lazy_load(self) -> Iterator[Document]: + """Load documents with rate limiting support, delegating to TavilyLoader.""" + valid_urls = [] + for url in self.web_paths: + try: + self._safe_process_url_sync(url) + valid_urls.append(url) + except Exception as e: + log.warning(f"SSL verification failed for {url}: {str(e)}") + if not self.continue_on_failure: + raise e + if not valid_urls: + if self.continue_on_failure: + log.warning("No valid URLs to process after SSL verification") + return + raise ValueError("No valid URLs to process after SSL verification") + try: + loader = TavilyLoader( + urls=valid_urls, + api_key=self.api_key, + extract_depth=self.extract_depth, + continue_on_failure=self.continue_on_failure, + ) + yield from loader.lazy_load() + except Exception as e: + if self.continue_on_failure: + log.exception(e, "Error extracting content from URLs") + else: + raise e + + async def alazy_load(self) -> AsyncIterator[Document]: + """Async version with rate limiting and SSL verification.""" + valid_urls = [] + for url in self.web_paths: + try: + await self._safe_process_url(url) + valid_urls.append(url) + except Exception as e: + log.warning(f"SSL verification failed for {url}: {str(e)}") + if not self.continue_on_failure: + raise e + + if not valid_urls: + if self.continue_on_failure: + log.warning("No valid URLs to process after SSL verification") + return + raise ValueError("No valid URLs to process after SSL verification") + + try: + loader = TavilyLoader( + urls=valid_urls, + api_key=self.api_key, + extract_depth=self.extract_depth, + continue_on_failure=self.continue_on_failure, + ) + async for document in loader.alazy_load(): + yield document + except Exception as e: + if self.continue_on_failure: + log.exception(e, "Error loading URLs") + else: + raise e -class SafePlaywrightURLLoader(PlaywrightURLLoader): +class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessingMixin): """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection. Attributes: @@ -356,40 +476,6 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader): raise e await browser.close() - def _verify_ssl_cert(self, url: str) -> bool: - return verify_ssl_cert(url) - - async def _wait_for_rate_limit(self): - """Wait to respect the rate limit if specified.""" - if self.requests_per_second and self.last_request_time: - min_interval = timedelta(seconds=1.0 / self.requests_per_second) - time_since_last = datetime.now() - self.last_request_time - if time_since_last < min_interval: - await asyncio.sleep((min_interval - time_since_last).total_seconds()) - self.last_request_time = datetime.now() - - def _sync_wait_for_rate_limit(self): - """Synchronous version of rate limit wait.""" - if self.requests_per_second and self.last_request_time: - min_interval = timedelta(seconds=1.0 / self.requests_per_second) - time_since_last = datetime.now() - self.last_request_time - if time_since_last < min_interval: - time.sleep((min_interval - time_since_last).total_seconds()) - self.last_request_time = datetime.now() - - async def _safe_process_url(self, url: str) -> bool: - """Perform safety checks before processing a URL.""" - if self.verify_ssl and not self._verify_ssl_cert(url): - raise ValueError(f"SSL certificate verification failed for {url}") - await self._wait_for_rate_limit() - return True - - def _safe_process_url_sync(self, url: str) -> bool: - """Synchronous version of safety checks.""" - if self.verify_ssl and not self._verify_ssl_cert(url): - raise ValueError(f"SSL certificate verification failed for {url}") - self._sync_wait_for_rate_limit() - return True class SafeWebBaseLoader(WebBaseLoader): @@ -499,6 +585,7 @@ RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader) RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader +RAG_WEB_LOADER_ENGINES["tavily"] = SafeTavilyLoader def get_web_loader( @@ -525,6 +612,10 @@ def get_web_loader( web_loader_args["api_key"] = FIRECRAWL_API_KEY.value web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value + if RAG_WEB_LOADER_ENGINE.value == "tavily": + web_loader_args["api_key"] = TAVILY_API_KEY.value + web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value + # Create the appropriate WebLoader based on the configuration WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value] web_loader = WebLoaderClass(**web_loader_args) diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 399283ee4..f30ae50c3 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -210,7 +210,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): LDAP_APP_DN, LDAP_APP_PASSWORD, auto_bind="NONE", - authentication="SIMPLE", + authentication="SIMPLE" if LDAP_APP_DN else "ANONYMOUS", ) if not connection_app.bind(): raise HTTPException(400, detail="Application account bind failed") diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 73b182d3c..bef286ca9 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -36,6 +36,9 @@ from open_webui.utils.payload import ( apply_model_params_to_body_openai, apply_model_system_prompt_to_body, ) +from open_webui.utils.misc import ( + convert_logit_bias_input_to_json, +) from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access @@ -396,6 +399,7 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: for idx, models in enumerate(model_lists): if models is not None and "error" not in models: + merged_list.extend( [ { @@ -406,18 +410,21 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: "urlIdx": idx, } for model in models - if "api.openai.com" - not in request.app.state.config.OPENAI_API_BASE_URLS[idx] - or not any( - name in model["id"] - for name in [ - "babbage", - "dall-e", - "davinci", - "embedding", - "tts", - "whisper", - ] + if (model.get("id") or model.get("name")) + and ( + "api.openai.com" + not in request.app.state.config.OPENAI_API_BASE_URLS[idx] + or not any( + name in model["id"] + for name in [ + "babbage", + "dall-e", + "davinci", + "embedding", + "tts", + "whisper", + ] + ) ) ] ) @@ -666,6 +673,11 @@ async def generate_chat_completion( del payload["max_tokens"] # Convert the modified body back to JSON + if "logit_bias" in payload: + payload["logit_bias"] = json.loads( + convert_logit_bias_input_to_json(payload["logit_bias"]) + ) + payload = json.dumps(payload) r = None diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 289d887df..ccb459865 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -189,17 +189,15 @@ async def chat_completion_tools_handler( tool_function_params = tool_call.get("parameters", {}) try: - required_params = ( - tools[tool_function_name] - .get("spec", {}) - .get("parameters", {}) - .get("required", []) + spec = tools[tool_function_name].get("spec", {}) + allowed_params = ( + spec.get("parameters", {}).get("properties", {}).keys() ) tool_function = tools[tool_function_name]["callable"] tool_function_params = { k: v for k, v in tool_function_params.items() - if k in required_params + if k in allowed_params } tool_output = await tool_function(**tool_function_params) @@ -1765,14 +1763,16 @@ async def process_chat_response( spec = tool.get("spec", {}) try: - required_params = spec.get("parameters", {}).get( - "required", [] + allowed_params = ( + spec.get("parameters", {}) + .get("properties", {}) + .keys() ) tool_function = tool["callable"] tool_function_params = { k: v for k, v in tool_function_params.items() - if k in required_params + if k in allowed_params } tool_result = await tool_function( **tool_function_params diff --git a/src/lib/components/AddConnectionModal.svelte b/src/lib/components/AddConnectionModal.svelte index cbd90b68d..f3132640a 100644 --- a/src/lib/components/AddConnectionModal.svelte +++ b/src/lib/components/AddConnectionModal.svelte @@ -179,7 +179,7 @@ - + - {#if !isFirstMessage && !readOnly} + {#if !readOnly && siblings.length > 1} + + + {#each tags as tag}