diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 5e0e4f0a1..14364d6e1 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -44,7 +44,7 @@ logging.getLogger("uvicorn.access").addFilter(EndpointFilter()) # Function to run the alembic migrations def run_migrations(): - print("Running migrations") + log.info("Running migrations") try: from alembic import command from alembic.config import Config @@ -57,7 +57,7 @@ def run_migrations(): command.upgrade(alembic_cfg, "head") except Exception as e: - print(f"Error: {e}") + log.exception(f"Error running migrations: {e}") run_migrations() @@ -678,6 +678,10 @@ S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None) S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None) S3_KEY_PREFIX = os.environ.get("S3_KEY_PREFIX", None) S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None) +S3_USE_ACCELERATE_ENDPOINT = ( + os.environ.get("S3_USE_ACCELERATE_ENDPOINT", "False").lower() == "true" +) +S3_ADDRESSING_STYLE = os.environ.get("S3_ADDRESSING_STYLE", None) GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None) GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get( @@ -1094,7 +1098,7 @@ try: banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]")) banners = [BannerModel(**banner) for banner in banners] except Exception as e: - print(f"Error loading WEBUI_BANNERS: {e}") + log.exception(f"Error loading WEBUI_BANNERS: {e}") banners = [] WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners) @@ -1566,6 +1570,18 @@ GOOGLE_DRIVE_API_KEY = PersistentConfig( os.environ.get("GOOGLE_DRIVE_API_KEY", ""), ) +ENABLE_ONEDRIVE_INTEGRATION = PersistentConfig( + "ENABLE_ONEDRIVE_INTEGRATION", + "onedrive.enable", + os.getenv("ENABLE_ONEDRIVE_INTEGRATION", "False").lower() == "true", +) + +ONEDRIVE_CLIENT_ID = PersistentConfig( + "ONEDRIVE_CLIENT_ID", + "onedrive.client_id", + os.environ.get("ONEDRIVE_CLIENT_ID", ""), +) + # RAG Content Extraction CONTENT_EXTRACTION_ENGINE = PersistentConfig( "CONTENT_EXTRACTION_ENGINE", @@ -1579,6 +1595,18 @@ TIKA_SERVER_URL = PersistentConfig( os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment ) +DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig( + "DOCUMENT_INTELLIGENCE_ENDPOINT", + "rag.document_intelligence_endpoint", + os.getenv("DOCUMENT_INTELLIGENCE_ENDPOINT", ""), +) + +DOCUMENT_INTELLIGENCE_KEY = PersistentConfig( + "DOCUMENT_INTELLIGENCE_KEY", + "rag.document_intelligence_key", + os.getenv("DOCUMENT_INTELLIGENCE_KEY", ""), +) + RAG_TOP_K = PersistentConfig( "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3")) ) diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 274be56ec..2f94f701e 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -2,6 +2,7 @@ import logging import sys import inspect import json +import asyncio from pydantic import BaseModel from typing import AsyncGenerator, Generator, Iterator @@ -76,11 +77,13 @@ async def get_function_models(request): if hasattr(function_module, "pipes"): sub_pipes = [] - # Check if pipes is a function or a list - + # Handle pipes being a list, sync function, or async function try: if callable(function_module.pipes): - sub_pipes = function_module.pipes() + if asyncio.iscoroutinefunction(function_module.pipes): + sub_pipes = await function_module.pipes() + else: + sub_pipes = function_module.pipes() else: sub_pipes = function_module.pipes except Exception as e: diff --git a/backend/open_webui/main.py 
b/backend/open_webui/main.py index 346d28d6c..31ea93399 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -95,6 +95,7 @@ from open_webui.config import ( OLLAMA_API_CONFIGS, # OpenAI ENABLE_OPENAI_API, + ONEDRIVE_CLIENT_ID, OPENAI_API_BASE_URLS, OPENAI_API_KEYS, OPENAI_API_CONFIGS, @@ -180,6 +181,8 @@ from open_webui.config import ( CHUNK_SIZE, CONTENT_EXTRACTION_ENGINE, TIKA_SERVER_URL, + DOCUMENT_INTELLIGENCE_ENDPOINT, + DOCUMENT_INTELLIGENCE_KEY, RAG_TOP_K, RAG_TEXT_SPLITTER, TIKTOKEN_ENCODING_NAME, @@ -215,11 +218,13 @@ from open_webui.config import ( GOOGLE_PSE_ENGINE_ID, GOOGLE_DRIVE_CLIENT_ID, GOOGLE_DRIVE_API_KEY, + ONEDRIVE_CLIENT_ID, ENABLE_RAG_HYBRID_SEARCH, ENABLE_RAG_LOCAL_WEB_FETCH, ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, ENABLE_RAG_WEB_SEARCH, ENABLE_GOOGLE_DRIVE_INTEGRATION, + ENABLE_ONEDRIVE_INTEGRATION, UPLOAD_DIR, # WebUI WEBUI_AUTH, @@ -533,6 +538,8 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL +app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT +app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME @@ -564,6 +571,7 @@ app.state.config.RAG_WEB_SEARCH_FULL_CONTEXT = RAG_WEB_SEARCH_FULL_CONTEXT app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION +app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID @@ -911,7 +919,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)): return filtered_models - models = await get_all_models(request) + models = await get_all_models(request, user=user) # Filter out filter pipelines models = [ @@ -951,7 +959,7 @@ async def chat_completion( user=Depends(get_verified_user), ): if not request.app.state.MODELS: - await get_all_models(request) + await get_all_models(request, user=user) model_item = form_data.pop("model_item", {}) tasks = form_data.pop("background_tasks", None) @@ -1146,6 +1154,7 @@ async def get_app_config(request: Request): "enable_admin_export": ENABLE_ADMIN_EXPORT, "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, + "enable_onedrive_integration": app.state.config.ENABLE_ONEDRIVE_INTEGRATION, } if user is not None else {} @@ -1177,6 +1186,7 @@ async def get_app_config(request: Request): "client_id": GOOGLE_DRIVE_CLIENT_ID.value, "api_key": GOOGLE_DRIVE_API_KEY.value, }, + "onedrive": {"client_id": ONEDRIVE_CLIENT_ID.value}, } if user is not None else {} diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 9e0a5865e..e9d04b986 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -1,3 +1,4 @@ +import logging import json import time import uuid @@ -5,7 +6,7 @@ from typing import Optional from open_webui.internal.db import Base, get_db from open_webui.models.tags import TagModel, Tag, Tags - +from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Boolean, 
Column, String, Text, JSON @@ -16,6 +17,8 @@ from sqlalchemy.sql import exists # Chat DB Schema #################### +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) class Chat(Base): __tablename__ = "chat" @@ -670,7 +673,7 @@ class ChatTable: # Perform pagination at the SQL level all_chats = query.offset(skip).limit(limit).all() - print(len(all_chats)) + log.info(f"The number of chats: {len(all_chats)}") # Validate and return chats return [ChatModel.model_validate(chat) for chat in all_chats] @@ -731,7 +734,7 @@ class ChatTable: query = db.query(Chat).filter_by(user_id=user_id) tag_id = tag_name.replace(" ", "_").lower() - print(db.bind.dialect.name) + log.info(f"DB dialect name: {db.bind.dialect.name}") if db.bind.dialect.name == "sqlite": # SQLite JSON1 querying for tags within the meta JSON field query = query.filter( @@ -752,7 +755,7 @@ class ChatTable: ) all_chats = query.all() - print("all_chats", all_chats) + log.debug(f"all_chats: {all_chats}") return [ChatModel.model_validate(chat) for chat in all_chats] def add_chat_tag_by_id_and_user_id_and_tag_name( @@ -810,7 +813,7 @@ class ChatTable: count = query.count() # Debugging output for inspection - print(f"Count of chats for tag '{tag_name}':", count) + log.info(f"Count of chats for tag '{tag_name}': {count}") return count diff --git a/backend/open_webui/models/feedbacks.py b/backend/open_webui/models/feedbacks.py index 7ff5c4540..215e36aa2 100644 --- a/backend/open_webui/models/feedbacks.py +++ b/backend/open_webui/models/feedbacks.py @@ -118,7 +118,7 @@ class FeedbackTable: else: return None except Exception as e: - print(e) + log.exception(f"Error creating a new feedback: {e}") return None def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]: diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index 91dea5444..6f1511cd1 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -119,7 +119,7 @@ class FilesTable: else: return None except Exception as e: - print(f"Error creating tool: {e}") + log.exception(f"Error inserting a new file: {e}") return None def get_file_by_id(self, id: str) -> Optional[FileModel]: diff --git a/backend/open_webui/models/folders.py b/backend/open_webui/models/folders.py index 040774196..19739bc5f 100644 --- a/backend/open_webui/models/folders.py +++ b/backend/open_webui/models/folders.py @@ -82,7 +82,7 @@ class FolderTable: else: return None except Exception as e: - print(e) + log.exception(f"Error inserting a new folder: {e}") return None def get_folder_by_id_and_user_id( diff --git a/backend/open_webui/models/functions.py b/backend/open_webui/models/functions.py index 6c6aed862..329795c05 100644 --- a/backend/open_webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -105,7 +105,7 @@ class FunctionsTable: else: return None except Exception as e: - print(f"Error creating tool: {e}") + log.exception(f"Error creating a new function: {e}") return None def get_function_by_id(self, id: str) -> Optional[FunctionModel]: @@ -170,7 +170,7 @@ class FunctionsTable: function = db.get(Function, id) return function.valves if function.valves else {} except Exception as e: - print(f"An error occurred: {e}") + log.exception(f"Error getting function valves by id {id}: {e}") return None def update_function_valves_by_id( @@ -202,7 +202,7 @@ class FunctionsTable: return user_settings["functions"]["valves"].get(id, {}) except Exception as e: - print(f"An error occurred: {e}") + log.exception(f"Error 
getting user valves by id {id} and user_id {user_id}: {e}") return None def update_user_valves_by_id_and_user_id( @@ -225,7 +225,7 @@ class FunctionsTable: return user_settings["functions"]["valves"][id] except Exception as e: - print(f"An error occurred: {e}") + log.exception(f"Error updating user valves by id {id} and user_id {user_id}: {e}") return None def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py old mode 100644 new mode 100755 index f2f59d7c4..7df8d8656 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -166,7 +166,7 @@ class ModelsTable: else: return None except Exception as e: - print(e) + log.exception(f"Failed to insert a new model: {e}") return None def get_all_models(self) -> list[ModelModel]: @@ -246,8 +246,7 @@ class ModelsTable: db.refresh(model) return ModelModel.model_validate(model) except Exception as e: - print(e) - + log.exception(f"Failed to update the model by id {id}: {e}") return None def delete_model_by_id(self, id: str) -> bool: diff --git a/backend/open_webui/models/tags.py b/backend/open_webui/models/tags.py index 3e812db95..279dc624d 100644 --- a/backend/open_webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -61,7 +61,7 @@ class TagTable: else: return None except Exception as e: - print(e) + log.exception(f"Error inserting a new tag: {e}") return None def get_tag_by_name_and_user_id( diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index a5f13ebb7..fb8745b58 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -131,7 +131,7 @@ class ToolsTable: else: return None except Exception as e: - print(f"Error creating tool: {e}") + log.exception(f"Error creating a new tool: {e}") return None def get_tool_by_id(self, id: str) -> Optional[ToolModel]: @@ -175,7 +175,7 @@ class ToolsTable: tool = db.get(Tool, id) return tool.valves if tool.valves else {} except Exception as e: - print(f"An error occurred: {e}") + log.exception(f"Error getting tool valves by id {id}: {e}") return None def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: @@ -204,7 +204,7 @@ return user_settings["tools"]["valves"].get(id, {}) except Exception as e: - print(f"An error occurred: {e}") + log.exception(f"Error getting user valves by id {id} and user_id {user_id}: {e}") return None def update_user_valves_by_id_and_user_id( @@ -227,7 +227,7 @@ return user_settings["tools"]["valves"][id] except Exception as e: - print(f"An error occurred: {e}") + log.exception(f"Error updating user valves by id {id} and user_id {user_id}: {e}") return None def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py index a9372f65a..19d590f5c 100644 --- a/backend/open_webui/retrieval/loaders/main.py +++ b/backend/open_webui/retrieval/loaders/main.py @@ -4,6 +4,7 @@ import ftfy import sys from langchain_community.document_loaders import ( + AzureAIDocumentIntelligenceLoader, BSHTMLLoader, CSVLoader, Docx2txtLoader, @@ -147,6 +148,27 @@ class Loader: file_path=file_path, mime_type=file_content_type, ) + elif ( + self.engine == "document_intelligence" + and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != "" + and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != "" + and ( + file_ext in
["pdf", "xls", "xlsx", "docx", "ppt", "pptx"] + or file_content_type + in [ + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-powerpoint", + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ] + ) + ): + loader = AzureAIDocumentIntelligenceLoader( + file_path=file_path, + api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"), + api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"), + ) else: if file_ext == "pdf": loader = PyPDFLoader( diff --git a/backend/open_webui/retrieval/models/colbert.py b/backend/open_webui/retrieval/models/colbert.py index ea3204cb8..5b7499fd1 100644 --- a/backend/open_webui/retrieval/models/colbert.py +++ b/backend/open_webui/retrieval/models/colbert.py @@ -1,13 +1,19 @@ import os +import logging import torch import numpy as np from colbert.infra import ColBERTConfig from colbert.modeling.checkpoint import Checkpoint +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + class ColBERT: def __init__(self, name, **kwargs) -> None: - print("ColBERT: Loading model", name) + log.info("ColBERT: Loading model", name) self.device = "cuda" if torch.cuda.is_available() else "cpu" DOCKER = kwargs.get("env") == "docker" diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index f83d09d9c..09af0eabb 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -81,7 +81,7 @@ def query_doc( return result except Exception as e: - print(e) + log.exception(f"Error querying doc {collection_name} with limit {k}: {e}") raise e @@ -94,7 +94,7 @@ def get_doc(collection_name: str, user: UserModel = None): return result except Exception as e: - print(e) + log.exception(f"Error getting doc {collection_name}: {e}") raise e @@ -530,7 +530,7 @@ def generate_openai_batch_embeddings( else: raise "Something went wrong :/" except Exception as e: - print(e) + log.exception(f"Error generating openai batch embeddings: {e}") return None @@ -564,7 +564,7 @@ def generate_ollama_batch_embeddings( else: raise "Something went wrong :/" except Exception as e: - print(e) + log.exception(f"Error generating ollama batch embeddings: {e}") return None diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py old mode 100644 new mode 100755 index c40618fcc..81f7a1c5e --- a/backend/open_webui/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -1,4 +1,5 @@ import chromadb +import logging from chromadb import Settings from chromadb.utils.batch_utils import create_batches @@ -16,6 +17,10 @@ from open_webui.config import ( CHROMA_CLIENT_AUTH_PROVIDER, CHROMA_CLIENT_AUTH_CREDENTIALS, ) +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) class ChromaClient: @@ -103,7 +108,7 @@ class ChromaClient: ) return None except Exception as e: - print(e) + log.exception(f"Error querying collection {collection} with limit {limit}: {e}") return None def get(self, collection_name: str) -> Optional[GetResult]: diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py index 43c3f3d1a..7d86fbdf6 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus.py +++ 
b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -1,7 +1,7 @@ from pymilvus import MilvusClient as Client from pymilvus import FieldSchema, DataType import json - +import logging from typing import Optional from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult @@ -10,6 +10,10 @@ from open_webui.config import ( MILVUS_DB, MILVUS_TOKEN, ) +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) class MilvusClient: @@ -168,7 +172,7 @@ class MilvusClient: try: # Loop until there are no more items to fetch or the desired limit is reached while remaining > 0: - print("remaining", remaining) + log.info(f"remaining: {remaining}") current_fetch = min( max_limit, remaining ) # Determine how many items to fetch in this iteration @@ -195,10 +199,10 @@ class MilvusClient: if results_count < current_fetch: break - print(all_results) + log.debug(all_results) return self._result_to_get_result([all_results]) except Exception as e: - print(e) + log.exception(f"Error querying collection {collection_name} with limit {limit}: {e}") return None def get(self, collection_name: str) -> Optional[GetResult]: diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index 341b3056f..59f717cb3 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -1,4 +1,5 @@ from typing import Optional, List, Dict, Any +import logging from sqlalchemy import ( cast, column, @@ -24,9 +25,14 @@ from sqlalchemy.exc import NoSuchTableError from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH +from open_webui.env import SRC_LOG_LEVELS + VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH Base = declarative_base() +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + class DocumentChunk(Base): __tablename__ = "document_chunk" @@ -82,10 +88,10 @@ class PgvectorClient: ) ) self.session.commit() - print("Initialization complete.") + log.info("Initialization complete.") except Exception as e: self.session.rollback() - print(f"Error during initialization: {e}") + log.exception(f"Error during initialization: {e}") raise def check_vector_length(self) -> None: @@ -150,12 +156,12 @@ class PgvectorClient: new_items.append(new_chunk) self.session.bulk_save_objects(new_items) self.session.commit() - print( + log.info( f"Inserted {len(new_items)} items into collection '{collection_name}'." 
) except Exception as e: self.session.rollback() - print(f"Error during insert: {e}") + log.exception(f"Error during insert: {e}") raise def upsert(self, collection_name: str, items: List[VectorItem]) -> None: @@ -184,10 +190,10 @@ class PgvectorClient: ) self.session.add(new_chunk) self.session.commit() - print(f"Upserted {len(items)} items into collection '{collection_name}'.") + log.info(f"Upserted {len(items)} items into collection '{collection_name}'.") except Exception as e: self.session.rollback() - print(f"Error during upsert: {e}") + log.exception(f"Error during upsert: {e}") raise def search( @@ -278,7 +284,7 @@ class PgvectorClient: ids=ids, distances=distances, documents=documents, metadatas=metadatas ) except Exception as e: - print(f"Error during search: {e}") + log.exception(f"Error during search: {e}") return None def query( @@ -310,7 +316,7 @@ class PgvectorClient: metadatas=metadatas, ) except Exception as e: - print(f"Error during query: {e}") + log.exception(f"Error during query: {e}") return None def get( @@ -334,7 +340,7 @@ class PgvectorClient: return GetResult(ids=ids, documents=documents, metadatas=metadatas) except Exception as e: - print(f"Error during get: {e}") + log.exception(f"Error during get: {e}") return None def delete( @@ -356,22 +362,22 @@ class PgvectorClient: ) deleted = query.delete(synchronize_session=False) self.session.commit() - print(f"Deleted {deleted} items from collection '{collection_name}'.") + log.info(f"Deleted {deleted} items from collection '{collection_name}'.") except Exception as e: self.session.rollback() - print(f"Error during delete: {e}") + log.exception(f"Error during delete: {e}") raise def reset(self) -> None: try: deleted = self.session.query(DocumentChunk).delete() self.session.commit() - print( + log.info( f"Reset complete. Deleted {deleted} items from 'document_chunk' table." 
) except Exception as e: self.session.rollback() - print(f"Error during reset: {e}") + log.exception(f"Error during reset: {e}") raise def close(self) -> None: @@ -387,9 +393,9 @@ class PgvectorClient: ) return exists except Exception as e: - print(f"Error checking collection existence: {e}") + log.exception(f"Error checking collection existence: {e}") return False def delete_collection(self, collection_name: str) -> None: self.delete(collection_name) - print(f"Collection '{collection_name}' deleted.") + log.info(f"Collection '{collection_name}' deleted.") diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py index f077ae45a..28f0b3779 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -1,4 +1,5 @@ from typing import Optional +import logging from qdrant_client import QdrantClient as Qclient from qdrant_client.http.models import PointStruct @@ -6,9 +7,13 @@ from qdrant_client.models import models from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import QDRANT_URI, QDRANT_API_KEY +from open_webui.env import SRC_LOG_LEVELS NO_LIMIT = 999999999 +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + class QdrantClient: def __init__(self): @@ -49,7 +54,7 @@ class QdrantClient: ), ) - print(f"collection {collection_name_with_prefix} successfully created!") + log.info(f"collection {collection_name_with_prefix} successfully created!") def _create_collection_if_not_exists(self, collection_name, dimension): if not self.has_collection(collection_name=collection_name): @@ -120,7 +125,7 @@ class QdrantClient: ) return self._result_to_get_result(points.points) except Exception as e: - print(e) + log.exception(f"Error querying a collection '{collection_name}': {e}") return None def get(self, collection_name: str) -> Optional[GetResult]: diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index a970366d1..c949e65a4 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -71,7 +71,7 @@ from pydub.utils import mediainfo def is_mp4_audio(file_path): """Check if the given file is an MP4 audio file.""" if not os.path.isfile(file_path): - print(f"File not found: {file_path}") + log.error(f"File not found: {file_path}") return False info = mediainfo(file_path) @@ -88,7 +88,7 @@ def convert_mp4_to_wav(file_path, output_path): """Convert MP4 audio file to WAV format.""" audio = AudioSegment.from_file(file_path, format="mp4") audio.export(output_path, format="wav") - print(f"Converted {file_path} to {output_path}") + log.info(f"Converted {file_path} to {output_path}") def set_faster_whisper_model(model: str, auto_update: bool = False): @@ -266,7 +266,6 @@ async def speech(request: Request, user=Depends(get_verified_user)): payload["model"] = request.app.state.config.TTS_MODEL try: - # print(payload) timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) async with aiohttp.ClientSession( timeout=timeout, trust_env=True @@ -468,7 +467,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): def transcribe(request: Request, file_path): - print("transcribe", file_path) + log.info(f"transcribe: {file_path}") filename = os.path.basename(file_path) file_dir = os.path.dirname(file_path) id = filename.split(".")[0] @@ -680,7 +679,22 @@ def transcription( def get_available_models(request: Request) -> list[dict]: available_models 
= [] if request.app.state.config.TTS_ENGINE == "openai": - available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}] + # Use custom endpoint if not using the official OpenAI API URL + if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith( + "https://api.openai.com" + ): + try: + response = requests.get( + f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models" + ) + response.raise_for_status() + data = response.json() + available_models = data.get("models", []) + except Exception as e: + log.error(f"Error fetching models from custom endpoint: {str(e)}") + available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}] + else: + available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}] elif request.app.state.config.TTS_ENGINE == "elevenlabs": try: response = requests.get( @@ -711,14 +725,37 @@ def get_available_voices(request) -> dict: """Returns {voice_id: voice_name} dict""" available_voices = {} if request.app.state.config.TTS_ENGINE == "openai": - available_voices = { - "alloy": "alloy", - "echo": "echo", - "fable": "fable", - "onyx": "onyx", - "nova": "nova", - "shimmer": "shimmer", - } + # Use custom endpoint if not using the official OpenAI API URL + if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith( + "https://api.openai.com" + ): + try: + response = requests.get( + f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices" + ) + response.raise_for_status() + data = response.json() + voices_list = data.get("voices", []) + available_voices = {voice["id"]: voice["name"] for voice in voices_list} + except Exception as e: + log.error(f"Error fetching voices from custom endpoint: {str(e)}") + available_voices = { + "alloy": "alloy", + "echo": "echo", + "fable": "fable", + "onyx": "onyx", + "nova": "nova", + "shimmer": "shimmer", + } + else: + available_voices = { + "alloy": "alloy", + "echo": "echo", + "fable": "fable", + "onyx": "onyx", + "nova": "nova", + "shimmer": "shimmer", + } elif request.app.state.config.TTS_ENGINE == "elevenlabs": try: available_voices = get_elevenlabs_voices( diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 3fa2ffe2e..59c6ed4a8 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -252,14 +252,6 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): if not user: try: user_count = Users.get_num_users() - if ( - request.app.state.USER_COUNT - and user_count >= request.app.state.USER_COUNT - ): - raise HTTPException( - status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) role = ( "admin" @@ -439,11 +431,6 @@ async def signup(request: Request, response: Response, form_data: SignupForm): ) user_count = Users.get_num_users() - if request.app.state.USER_COUNT and user_count >= request.app.state.USER_COUNT: - raise HTTPException( - status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED - ) - if not validate_email_format(form_data.email.lower()): raise HTTPException( status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT @@ -613,7 +600,7 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)): admin_email = request.app.state.config.ADMIN_EMAIL admin_name = None - print(admin_email, admin_name) + log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}") if admin_email: admin = Users.get_user_by_email(admin_email) diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 504baa60d..bccf74b24 100644 --- 
a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -273,7 +273,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): # Check if the file already exists in the cache if file_path.is_file(): - print(f"file_path: {file_path}") + log.info(f"file_path: {file_path}") return FileResponse(file_path) else: raise HTTPException( diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index 7f3305f25..ac2db9322 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -1,4 +1,5 @@ import os +import logging from pathlib import Path from typing import Optional @@ -13,6 +14,11 @@ from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + router = APIRouter() @@ -79,7 +85,7 @@ async def create_new_function( detail=ERROR_MESSAGES.DEFAULT("Error creating function"), ) except Exception as e: - print(e) + log.exception(f"Failed to create a new function: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -183,7 +189,7 @@ async def update_function_by_id( FUNCTIONS[id] = function_module updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} - print(updated) + log.debug(updated) function = Functions.update_function_by_id(id, updated) @@ -299,7 +305,7 @@ async def update_function_valves_by_id( Functions.update_function_valves_by_id(id, valves.model_dump()) return valves.model_dump() except Exception as e: - print(e) + log.exception(f"Error updating function values by id {id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -388,7 +394,7 @@ async def update_function_user_valves_by_id( ) return user_valves.model_dump() except Exception as e: - print(e) + log.exception(f"Error updating function user valves by id {id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py old mode 100644 new mode 100755 index 5b5130f71..ae822c0d0 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -1,7 +1,7 @@ import os from pathlib import Path from typing import Optional - +import logging from open_webui.models.users import Users from open_webui.models.groups import ( @@ -14,7 +14,13 @@ from open_webui.models.groups import ( from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status + from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.env import SRC_LOG_LEVELS + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) router = APIRouter() @@ -37,7 +43,7 @@ async def get_groups(user=Depends(get_verified_user)): @router.post("/create", response_model=Optional[GroupResponse]) -async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)): +async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)): try: group = Groups.insert_new_group(user.id, form_data) if group: @@ -48,7 +54,7 @@ async def create_new_function(form_data: 
GroupForm, user=Depends(get_admin_user) detail=ERROR_MESSAGES.DEFAULT("Error creating group"), ) except Exception as e: - print(e) + log.exception(f"Error creating a new group: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -94,7 +100,7 @@ async def update_group_by_id( detail=ERROR_MESSAGES.DEFAULT("Error updating group"), ) except Exception as e: - print(e) + log.exception(f"Error updating group {id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -118,7 +124,7 @@ async def delete_group_by_id(id: str, user=Depends(get_admin_user)): detail=ERROR_MESSAGES.DEFAULT("Error deleting group"), ) except Exception as e: - print(e) + log.exception(f"Error deleting group {id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 3288ec6d8..7187856e7 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -351,7 +351,7 @@ def get_models(request: Request, user=Depends(get_verified_user)): if model_node_id: model_list_key = None - print(workflow[model_node_id]["class_type"]) + log.info(workflow[model_node_id]["class_type"]) for key in info[workflow[model_node_id]["class_type"]]["input"][ "required" ]: diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 0ba6191a2..196904550 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -614,7 +614,7 @@ def add_files_to_knowledge_batch( ) # Get files content - print(f"files/batch/add - {len(form_data)} files") + log.info(f"files/batch/add - {len(form_data)} files") files: List[FileModel] = [] for form in form_data: file = Files.get_file_by_id(form.file_id) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 732dd36f9..d99416c83 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -14,6 +14,11 @@ from urllib.parse import urlparse import aiohttp from aiocache import cached import requests +from open_webui.models.users import UserModel + +from open_webui.env import ( + ENABLE_FORWARD_USER_INFO_HEADERS, +) from fastapi import ( Depends, @@ -66,12 +71,26 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) ########################################## -async def send_get_request(url, key=None): +async def send_get_request(url, key=None, user: UserModel = None): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get( - url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + url, + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) as response: return await response.json() except Exception as e: @@ -96,6 +115,7 @@ async def send_post_request( stream: bool = True, key: Optional[str] = None, content_type: Optional[str] = None, + user: UserModel = None, ): r = None @@ -110,6 +130,16 @@ async def send_post_request( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer 
{key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, ) r.raise_for_status() @@ -191,7 +221,19 @@ async def verify_connection( try: async with session.get( f"{url}/api/version", - headers={**({"Authorization": f"Bearer {key}"} if key else {})}, + headers={ + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) as r: if r.status != 200: detail = f"HTTP Error: {r.status}" @@ -254,7 +296,7 @@ async def update_config( @cached(ttl=3) -async def get_all_models(request: Request): +async def get_all_models(request: Request, user: UserModel = None): log.info("get_all_models()") if request.app.state.config.ENABLE_OLLAMA_API: request_tasks = [] @@ -262,7 +304,7 @@ async def get_all_models(request: Request): if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and ( url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support ): - request_tasks.append(send_get_request(f"{url}/api/tags")) + request_tasks.append(send_get_request(f"{url}/api/tags", user=user)) else: api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( str(idx), @@ -275,7 +317,9 @@ async def get_all_models(request: Request): key = api_config.get("key", None) if enable: - request_tasks.append(send_get_request(f"{url}/api/tags", key)) + request_tasks.append( + send_get_request(f"{url}/api/tags", key, user=user) + ) else: request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) @@ -360,7 +404,7 @@ async def get_ollama_tags( models = [] if url_idx is None: - models = await get_all_models(request) + models = await get_all_models(request, user=user) else: url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) @@ -370,7 +414,19 @@ async def get_ollama_tags( r = requests.request( method="GET", url=f"{url}/api/tags", - headers={**({"Authorization": f"Bearer {key}"} if key else {})}, + headers={ + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) r.raise_for_status() @@ -477,6 +533,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u url, {} ), # Legacy support ).get("key", None), + user=user, ) for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS) ] @@ -509,6 +566,7 @@ async def pull_model( url=f"{url}/api/pull", payload=json.dumps(payload), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -527,7 +585,7 @@ async def push_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.name in models: @@ -545,6 +603,7 @@ async def push_model( url=f"{url}/api/push", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -571,6 +630,7 @@ async def 
create_model( url=f"{url}/api/create", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -588,7 +648,7 @@ async def copy_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.source in models: @@ -609,6 +669,16 @@ async def copy_model( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -643,7 +713,7 @@ async def delete_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.name in models: @@ -665,6 +735,16 @@ async def delete_model( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, ) r.raise_for_status() @@ -693,7 +773,7 @@ async def delete_model( async def show_model_info( request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) ): - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.name not in models: @@ -714,6 +794,16 @@ async def show_model_info( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -757,7 +847,7 @@ async def embed( log.info(f"generate_ollama_batch_embeddings {form_data}") if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -783,6 +873,16 @@ async def embed( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -826,7 +926,7 @@ async def embeddings( log.info(f"generate_ollama_embeddings {form_data}") if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -852,6 +952,16 @@ async def embeddings( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + 
), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -901,7 +1011,7 @@ async def generate_completion( user=Depends(get_verified_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -931,6 +1041,7 @@ async def generate_completion( url=f"{url}/api/generate", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -1060,6 +1171,7 @@ async def generate_chat_completion( stream=form_data.stream, key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), content_type="application/x-ndjson", + user=user, ) @@ -1162,6 +1274,7 @@ async def generate_openai_completion( payload=json.dumps(payload), stream=payload.get("stream", False), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -1240,6 +1353,7 @@ async def generate_openai_chat_completion( payload=json.dumps(payload), stream=payload.get("stream", False), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -1253,7 +1367,7 @@ async def get_openai_models( models = [] if url_idx is None: - model_list = await get_all_models(request) + model_list = await get_all_models(request, user=user) models = [ { "id": model["model"], diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index afda36237..dff2461ea 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -26,6 +26,7 @@ from open_webui.env import ( ENABLE_FORWARD_USER_INFO_HEADERS, BYPASS_MODEL_ACCESS_CONTROL, ) +from open_webui.models.users import UserModel from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENV, SRC_LOG_LEVELS @@ -51,12 +52,25 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"]) ########################################## -async def send_get_request(url, key=None): +async def send_get_request(url, key=None, user: UserModel = None): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get( - url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + url, + headers={ + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) as response: return await response.json() except Exception as e: @@ -84,9 +98,15 @@ def openai_o1_o3_handler(payload): payload["max_completion_tokens"] = payload["max_tokens"] del payload["max_tokens"] - # Fix: O1 does not support the "system" parameter, Modify "system" to "user" + # Fix: o1 and o3 do not support the "system" role directly. + # For older models like "o1-mini" or "o1-preview", use role "user". + # For newer o1/o3 models, replace "system" with "developer".
if payload["messages"][0]["role"] == "system": - payload["messages"][0]["role"] = "user" + model_lower = payload["model"].lower() + if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"): + payload["messages"][0]["role"] = "user" + else: + payload["messages"][0]["role"] = "developer" return payload @@ -247,7 +267,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) -async def get_all_models_responses(request: Request) -> list: +async def get_all_models_responses(request: Request, user: UserModel) -> list: if not request.app.state.config.ENABLE_OPENAI_API: return [] @@ -271,7 +291,9 @@ async def get_all_models_responses(request: Request) -> list: ): request_tasks.append( send_get_request( - f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx] + f"{url}/models", + request.app.state.config.OPENAI_API_KEYS[idx], + user=user, ) ) else: @@ -291,6 +313,7 @@ async def get_all_models_responses(request: Request) -> list: send_get_request( f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx], + user=user, ) ) else: @@ -352,13 +375,13 @@ async def get_filtered_models(models, user): @cached(ttl=3) -async def get_all_models(request: Request) -> dict[str, list]: +async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: log.info("get_all_models()") if not request.app.state.config.ENABLE_OPENAI_API: return {"data": []} - responses = await get_all_models_responses(request) + responses = await get_all_models_responses(request, user=user) def extract_data(response): if response and "data" in response: @@ -418,7 +441,7 @@ async def get_models( } if url_idx is None: - models = await get_all_models(request) + models = await get_all_models(request, user=user) else: url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx] key = request.app.state.config.OPENAI_API_KEYS[url_idx] @@ -515,6 +538,16 @@ async def verify_connection( headers={ "Authorization": f"Bearer {key}", "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), }, ) as r: if r.status != 200: @@ -587,7 +620,7 @@ async def generate_chat_completion( detail="Model not found", ) - await get_all_models(request) + await get_all_models(request, user=user) model = request.app.state.OPENAI_MODELS.get(model_id) if model: idx = model["urlIdx"] @@ -777,7 +810,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): if r is not None: try: res = await r.json() - print(res) + log.error(res) if "error" in res: detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" except Exception: diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py index ad280b65c..598f90f65 100644 --- a/backend/open_webui/routers/pipelines.py +++ b/backend/open_webui/routers/pipelines.py @@ -101,7 +101,7 @@ async def process_pipeline_inlet_filter(request, payload, user, models): if "detail" in res: raise Exception(response.status, res["detail"]) except Exception as e: - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") return payload @@ -153,7 +153,7 @@ async def process_pipeline_outlet_filter(request, payload, user, models): except Exception: pass except Exception as e: - print(f"Connection error: {e}") + 
log.exception(f"Connection error: {e}") return payload @@ -196,7 +196,7 @@ async def upload_pipeline( file: UploadFile = File(...), user=Depends(get_admin_user), ): - print("upload_pipeline", urlIdx, file.filename) + log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}") # Check if the uploaded file is a python file if not (file.filename and file.filename.endswith(".py")): raise HTTPException( @@ -231,7 +231,7 @@ async def upload_pipeline( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") detail = None status_code = status.HTTP_404_NOT_FOUND @@ -282,7 +282,7 @@ async def add_pipeline( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") detail = None if r is not None: @@ -327,7 +327,7 @@ async def delete_pipeline( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") detail = None if r is not None: @@ -361,7 +361,7 @@ async def get_pipelines( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") detail = None if r is not None: @@ -400,7 +400,7 @@ async def get_pipeline_valves( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") detail = None if r is not None: @@ -440,7 +440,7 @@ async def get_pipeline_valves_spec( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") detail = None if r is not None: @@ -482,7 +482,7 @@ async def update_pipeline_valves( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") detail = None diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index e69d2ce96..9279f9fa3 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -353,9 +353,14 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, "enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, + "enable_onedrive_integration": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, "content_extraction": { "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, "tika_server_url": request.app.state.config.TIKA_SERVER_URL, + "document_intelligence_config": { + "endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, + "key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, + }, }, "chunk": { "text_splitter": request.app.state.config.TEXT_SPLITTER, @@ -377,6 +382,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "search": { "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH, "drive": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, + "onedrive": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE, "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL, "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY, @@ -399,6 +405,7 @@ async def 
get_rag_config(request: Request, user=Depends(get_admin_user)): "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, "exa_api_key": request.app.state.config.EXA_API_KEY, "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + "trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV, "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, }, @@ -411,9 +418,15 @@ class FileConfig(BaseModel): max_count: Optional[int] = None +class DocumentIntelligenceConfigForm(BaseModel): + endpoint: str + key: str + + class ContentExtractionConfig(BaseModel): engine: str = "" tika_server_url: Optional[str] = None + document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None class ChunkParamUpdateForm(BaseModel): @@ -467,6 +480,7 @@ class ConfigUpdateForm(BaseModel): RAG_FULL_CONTEXT: Optional[bool] = None pdf_extract_images: Optional[bool] = None enable_google_drive_integration: Optional[bool] = None + enable_onedrive_integration: Optional[bool] = None file: Optional[FileConfig] = None content_extraction: Optional[ContentExtractionConfig] = None chunk: Optional[ChunkParamUpdateForm] = None @@ -496,18 +510,33 @@ async def update_rag_config( else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION ) + request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ( + form_data.enable_onedrive_integration + if form_data.enable_onedrive_integration is not None + else request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION + ) + if form_data.file is not None: request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count if form_data.content_extraction is not None: - log.info(f"Updating text settings: {form_data.content_extraction}") + log.info( + f"Updating content extraction: {request.app.state.config.CONTENT_EXTRACTION_ENGINE} to {form_data.content_extraction.engine}" + ) request.app.state.config.CONTENT_EXTRACTION_ENGINE = ( form_data.content_extraction.engine ) request.app.state.config.TIKA_SERVER_URL = ( form_data.content_extraction.tika_server_url ) + if form_data.content_extraction.document_intelligence_config is not None: + request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = ( + form_data.content_extraction.document_intelligence_config.endpoint + ) + request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = ( + form_data.content_extraction.document_intelligence_config.key + ) if form_data.chunk is not None: request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter @@ -604,6 +633,10 @@ async def update_rag_config( "content_extraction": { "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, "tika_server_url": request.app.state.config.TIKA_SERVER_URL, + "document_intelligence_config": { + "endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, + "key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, + }, }, "chunk": { "text_splitter": request.app.state.config.TEXT_SPLITTER, @@ -937,6 +970,8 @@ def process_file( engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE, TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL, PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES, + DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, + DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, ) docs = loader.load( file.filename, 
file.meta.get("content_type"), file_path @@ -1553,11 +1588,11 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool: elif os.path.isdir(file_path): shutil.rmtree(file_path) # Remove the directory except Exception as e: - print(f"Failed to delete {file_path}. Reason: {e}") + log.exception(f"Failed to delete {file_path}. Reason: {e}") else: - print(f"The directory {folder} does not exist") + log.warning(f"The directory {folder} does not exist") except Exception as e: - print(f"Failed to process the directory {folder}. Reason: {e}") + log.exception(f"Failed to process the directory {folder}. Reason: {e}") return True diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index 0328cefe0..b63c9732a 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -20,6 +20,10 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.constants import TASKS from open_webui.routers.pipelines import process_pipeline_inlet_filter +from open_webui.utils.filter import ( + get_sorted_filter_ids, + process_filter_functions, +) from open_webui.utils.task import get_task_model_id from open_webui.config import ( @@ -221,6 +225,12 @@ async def generate_title( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: @@ -290,6 +300,12 @@ async def generate_chat_tags( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: @@ -356,6 +372,12 @@ async def generate_image_prompt( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: @@ -433,6 +455,12 @@ async def generate_queries( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: @@ -514,6 +542,12 @@ async def generate_autocompletion( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: @@ -584,6 +618,12 @@ async def generate_emoji( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: @@ -644,6 +684,12 @@ async def generate_moa_response( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: diff --git 
a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index d6a5c5532..5e4109037 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path from typing import Optional @@ -15,6 +16,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.tools import get_tools_specs from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) router = APIRouter() @@ -111,7 +116,7 @@ async def create_new_tools( detail=ERROR_MESSAGES.DEFAULT("Error creating tools"), ) except Exception as e: - print(e) + log.exception(f"Failed to load the tool by id {form_data.id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(str(e)), @@ -193,7 +198,7 @@ async def update_tools_by_id( "specs": specs, } - print(updated) + log.debug(updated) tools = Tools.update_tool_by_id(id, updated) if tools: @@ -343,7 +348,7 @@ async def update_tools_valves_by_id( Tools.update_tool_valves_by_id(id, valves.model_dump()) return valves.model_dump() except Exception as e: - print(e) + log.exception(f"Failed to update tool valves by id {id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(str(e)), @@ -421,7 +426,7 @@ async def update_tools_user_valves_by_id( ) return user_valves.model_dump() except Exception as e: - print(e) + log.exception(f"Failed to update user valves by id {id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(str(e)), diff --git a/backend/open_webui/routers/utils.py b/backend/open_webui/routers/utils.py index fb1dc8272..237e732de 100644 --- a/backend/open_webui/routers/utils.py +++ b/backend/open_webui/routers/utils.py @@ -1,4 +1,5 @@ import black +import logging import markdown from open_webui.models.chats import ChatTitleMessagesForm @@ -13,11 +14,14 @@ from open_webui.utils.misc import get_gravatar_url from open_webui.utils.pdf_generator import PDFGenerator from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.code_interpreter import execute_code_jupyter +from open_webui.env import SRC_LOG_LEVELS +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + router = APIRouter() - @router.get("/gravatar") async def get_gravatar(email: str, user=Depends(get_verified_user)): return get_gravatar_url(email) @@ -96,7 +100,7 @@ async def download_chat_as_pdf( headers={"Content-Disposition": "attachment;filename=chat.pdf"}, ) except Exception as e: - print(e) + log.exception(f"Error generating PDF: {e}") raise HTTPException(status_code=400, detail=str(e)) diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py index 160a45153..2f31cbdaf 100644 --- a/backend/open_webui/storage/provider.py +++ b/backend/open_webui/storage/provider.py @@ -1,10 +1,12 @@ import os import shutil import json +import logging from abc import ABC, abstractmethod from typing import BinaryIO, Tuple import boto3 +from botocore.config import Config from botocore.exceptions import ClientError from open_webui.config import ( S3_ACCESS_KEY_ID, @@ -13,6 +15,8 @@ from open_webui.config import ( S3_KEY_PREFIX, S3_REGION_NAME, S3_SECRET_ACCESS_KEY, + S3_USE_ACCELERATE_ENDPOINT, + 
S3_ADDRESSING_STYLE, GCS_BUCKET_NAME, GOOGLE_APPLICATION_CREDENTIALS_JSON, AZURE_STORAGE_ENDPOINT, @@ -27,6 +31,11 @@ from open_webui.constants import ERROR_MESSAGES from azure.identity import DefaultAzureCredential from azure.storage.blob import BlobServiceClient from azure.core.exceptions import ResourceNotFoundError +from open_webui.env import SRC_LOG_LEVELS + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) class StorageProvider(ABC): @@ -71,7 +80,7 @@ class LocalStorageProvider(StorageProvider): if os.path.isfile(file_path): os.remove(file_path) else: - print(f"File {file_path} not found in local storage.") + log.warning(f"File {file_path} not found in local storage.") @staticmethod def delete_all_files() -> None: @@ -85,9 +94,9 @@ class LocalStorageProvider(StorageProvider): elif os.path.isdir(file_path): shutil.rmtree(file_path) # Remove the directory except Exception as e: - print(f"Failed to delete {file_path}. Reason: {e}") + log.exception(f"Failed to delete {file_path}. Reason: {e}") else: - print(f"Directory {UPLOAD_DIR} not found in local storage.") + log.warning(f"Directory {UPLOAD_DIR} not found in local storage.") class S3StorageProvider(StorageProvider): @@ -98,6 +107,12 @@ class S3StorageProvider(StorageProvider): endpoint_url=S3_ENDPOINT_URL, aws_access_key_id=S3_ACCESS_KEY_ID, aws_secret_access_key=S3_SECRET_ACCESS_KEY, + config=Config( + s3={ + "use_accelerate_endpoint": S3_USE_ACCELERATE_ENDPOINT, + "addressing_style": S3_ADDRESSING_STYLE, + }, + ), ) self.bucket_name = S3_BUCKET_NAME self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else "" diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index e478284a6..529fcdf9d 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -14,14 +14,17 @@ from typing import Optional, Union, List, Dict from open_webui.models.users import Users from open_webui.constants import ERROR_MESSAGES -from open_webui.env import WEBUI_SECRET_KEY, TRUSTED_SIGNATURE_KEY, STATIC_DIR +from open_webui.env import WEBUI_SECRET_KEY, TRUSTED_SIGNATURE_KEY, STATIC_DIR, SRC_LOG_LEVELS from fastapi import Depends, HTTPException, Request, Response, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from passlib.context import CryptContext + logging.getLogger("passlib").setLevel(logging.ERROR) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["OAUTH"]) SESSION_SECRET = WEBUI_SECRET_KEY ALGORITHM = "HS256" @@ -50,7 +53,7 @@ def verify_signature(payload: str, signature: str) -> bool: def override_static(path: str, content: str): # Ensure path is safe if "/" in path or ".." 
in path: - print(f"Invalid path: {path}") + log.error(f"Invalid path: {path}") return file_path = os.path.join(STATIC_DIR, path) @@ -82,11 +85,11 @@ def get_license_data(app, key): return True else: - print( + log.error( f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}" ) except Exception as ex: - print(f"License: Uncaught Exception: {ex}") + log.exception(f"License: Uncaught Exception: {ex}") return False diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 209fb02dc..74d0af4f7 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -66,7 +66,7 @@ async def generate_direct_chat_completion( user: Any, models: dict, ): - print("generate_direct_chat_completion") + log.info("generate_direct_chat_completion") metadata = form_data.pop("metadata", {}) @@ -103,7 +103,7 @@ async def generate_direct_chat_completion( } ) - print("res", res) + log.info(f"res: {res}") if res.get("status", False): # Define a generator to stream responses @@ -285,7 +285,7 @@ chat_completion = generate_chat_completion async def chat_completed(request: Request, form_data: dict, user: Any): if not request.app.state.MODELS: - await get_all_models(request) + await get_all_models(request, user=user) if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { @@ -351,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A raise Exception(f"Action not found: {action_id}") if not request.app.state.MODELS: - await get_all_models(request) + await get_all_models(request, user=user) if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { @@ -432,7 +432,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A ) ) except Exception as e: - print(e) + log.exception(f"Failed to get user values: {e}") params = {**params, "__user__": __user__} diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index de51bd46e..0ca754ed8 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -1,6 +1,12 @@ import inspect +import logging + from open_webui.utils.plugin import load_function_module_by_id from open_webui.models.functions import Functions +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) def get_sorted_filter_ids(model): @@ -61,7 +67,12 @@ async def process_filter_functions( try: # Prepare parameters sig = inspect.signature(handler) - params = {"body": form_data} | { + + params = {"body": form_data} + if filter_type == "stream": + params = {"event": form_data} + + params = params | { k: v for k, v in { **extra_params, @@ -80,7 +91,7 @@ async def process_filter_functions( ) ) except Exception as e: - print(e) + log.exception(f"Failed to get user values: {e}") # Execute handler if inspect.iscoroutinefunction(handler): @@ -89,7 +100,7 @@ async def process_filter_functions( form_data = handler(**params) except Exception as e: - print(f"Error in {filter_type} handler {filter_id}: {e}") + log.exception(f"Error in {filter_type} handler {filter_id}: {e}") raise e # Handle file cleanup for inlet diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 7ec764fc0..f479da40c 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -1048,6 +1048,21 @@ async def process_chat_response( ): return response + extra_params = { + 
"__event_emitter__": event_emitter, + "__event_call__": event_caller, + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + "__metadata__": metadata, + "__request__": request, + "__model__": metadata.get("model"), + } + filter_ids = get_sorted_filter_ids(form_data.get("model")) + # Streaming response if event_emitter and event_caller: task_id = str(uuid4()) # Create a unique task ID. @@ -1127,12 +1142,12 @@ async def process_chat_response( if reasoning_duration is not None: if raw: - content = f'{content}\n<{block["tag"]}>{block["content"]}\n' + content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n' else: content = f'{content}\n
\nThought for {reasoning_duration} seconds\n{reasoning_display_content}\n
\n' else: if raw: - content = f'{content}\n<{block["tag"]}>{block["content"]}\n' + content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n' else: content = f'{content}\n
\nThinking…\n{reasoning_display_content}\n
@@ -1228,9 +1243,9 @@
                     return attributes

                 if content_blocks[-1]["type"] == "text":
-                    for tag in tags:
+                    for start_tag, end_tag in tags:
                         # Match start tag e.g., <think> or <think type="reasoning">
-                        start_tag_pattern = rf"<{tag}(\s.*?)?>"
+                        start_tag_pattern = rf"<{re.escape(start_tag)}(\s.*?)?>"
                         match = re.search(start_tag_pattern, content)
                         if match:
                             attr_content = (
@@ -1263,7 +1278,8 @@
                             content_blocks.append(
                                 {
                                     "type": content_type,
-                                    "tag": tag,
+                                    "start_tag": start_tag,
+                                    "end_tag": end_tag,
                                     "attributes": attributes,
                                     "content": "",
                                     "started_at": time.time(),
@@ -1275,9 +1291,10 @@
                             break
                 elif content_blocks[-1]["type"] == content_type:
-                    tag = content_blocks[-1]["tag"]
+                    start_tag = content_blocks[-1]["start_tag"]
+                    end_tag = content_blocks[-1]["end_tag"]
                     # Match end tag e.g., </think>
-                    end_tag_pattern = rf"</{tag}>"
+                    end_tag_pattern = rf"<{re.escape(end_tag)}>"

                     # Check if the content has the end tag
                     if re.search(end_tag_pattern, content):
@@ -1285,7 +1302,7 @@
                         block_content = content_blocks[-1]["content"]

                         # Strip start and end tags from the content
-                        start_tag_pattern = rf"<{tag}(.*?)>"
+                        start_tag_pattern = rf"<{re.escape(start_tag)}(.*?)>"
                         block_content = re.sub(
                             start_tag_pattern, "", block_content
                         ).strip()
@@ -1350,7 +1367,7 @@
                         # Clean processed content
                         content = re.sub(
-                            rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
+                            rf"<{re.escape(start_tag)}(.*?)>(.|\n)*?<{re.escape(end_tag)}>",
                             "",
                             content,
                             flags=re.DOTALL,
@@ -1388,19 +1405,24 @@
         # We might want to disable this by default
         DETECT_REASONING = True
+        DETECT_SOLUTION = True
         DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
             "code_interpreter", False
         )

         reasoning_tags = [
-            "think",
-            "thinking",
-            "reason",
-            "reasoning",
-            "thought",
-            "Thought",
+            ("think", "/think"),
+            ("thinking", "/thinking"),
+            ("reason", "/reason"),
+            ("reasoning", "/reasoning"),
+            ("thought", "/thought"),
+            ("Thought", "/Thought"),
+            ("|begin_of_thought|", "|end_of_thought|"),
         ]
-        code_interpreter_tags = ["code_interpreter"]
+
+        code_interpreter_tags = [("code_interpreter", "/code_interpreter")]
+
+        solution_tags = [("|begin_of_solution|", "|end_of_solution|")]

         try:
             for event in events:
@@ -1444,119 +1466,154 @@
                     try:
                         data = json.loads(data)

-                        if "selected_model_id" in data:
-                            model_id = data["selected_model_id"]
-                            Chats.upsert_message_to_chat_by_id_and_message_id(
-                                metadata["chat_id"],
-                                metadata["message_id"],
-                                {
-                                    "selectedModelId": model_id,
-                                },
-                            )
-                        else:
-                            choices = data.get("choices", [])
-                            if not choices:
-                                continue
+                        data, _ = await process_filter_functions(
+                            request=request,
+                            filter_ids=filter_ids,
+                            filter_type="stream",
+                            form_data=data,
+                            extra_params=extra_params,
+                        )

-                            delta = choices[0].get("delta", {})
-                            delta_tool_calls = delta.get("tool_calls", None)
-
-                            if delta_tool_calls:
-                                for delta_tool_call in delta_tool_calls:
-                                    tool_call_index = delta_tool_call.get("index")
-
-                                    if tool_call_index is not None:
-                                        if (
-                                            len(response_tool_calls)
-                                            <= tool_call_index
-                                        ):
-                                            response_tool_calls.append(
-                                                delta_tool_call
-                                            )
-                                        else:
-                                            delta_name = delta_tool_call.get(
-                                                "function", {}
-                                            ).get("name")
-                                            delta_arguments = delta_tool_call.get(
-                                                "function", {}
-                                            ).get("arguments")
-
-                                            if delta_name:
-                                                response_tool_calls[
-                                                    tool_call_index
-                                                ]["function"]["name"] += delta_name
-
-                                            if delta_arguments:
-                                                response_tool_calls[
-                                                    tool_call_index
-                                                ]["function"][
-                                                    "arguments"
-                                                ] += 
delta_arguments - - value = delta.get("content") - - if value: - content = f"{content}{value}" - - if not content_blocks: - content_blocks.append( - { - "type": "text", - "content": "", - } - ) - - content_blocks[-1]["content"] = ( - content_blocks[-1]["content"] + value + if data: + if "selected_model_id" in data: + model_id = data["selected_model_id"] + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "selectedModelId": model_id, + }, ) - - if DETECT_REASONING: - content, content_blocks, _ = ( - tag_content_handler( - "reasoning", - reasoning_tags, - content, - content_blocks, + else: + choices = data.get("choices", []) + if not choices: + usage = data.get("usage", {}) + if usage: + await event_emitter( + { + "type": "chat:completion", + "data": { + "usage": usage, + }, + } ) + continue + + delta = choices[0].get("delta", {}) + delta_tool_calls = delta.get("tool_calls", None) + + if delta_tool_calls: + for delta_tool_call in delta_tool_calls: + tool_call_index = delta_tool_call.get( + "index" + ) + + if tool_call_index is not None: + if ( + len(response_tool_calls) + <= tool_call_index + ): + response_tool_calls.append( + delta_tool_call + ) + else: + delta_name = delta_tool_call.get( + "function", {} + ).get("name") + delta_arguments = ( + delta_tool_call.get( + "function", {} + ).get("arguments") + ) + + if delta_name: + response_tool_calls[ + tool_call_index + ]["function"][ + "name" + ] += delta_name + + if delta_arguments: + response_tool_calls[ + tool_call_index + ]["function"][ + "arguments" + ] += delta_arguments + + value = delta.get("content") + + if value: + content = f"{content}{value}" + + if not content_blocks: + content_blocks.append( + { + "type": "text", + "content": "", + } + ) + + content_blocks[-1]["content"] = ( + content_blocks[-1]["content"] + value ) - if DETECT_CODE_INTERPRETER: - content, content_blocks, end = ( - tag_content_handler( - "code_interpreter", - code_interpreter_tags, - content, - content_blocks, + if DETECT_REASONING: + content, content_blocks, _ = ( + tag_content_handler( + "reasoning", + reasoning_tags, + content, + content_blocks, + ) ) - ) - if end: - break + if DETECT_CODE_INTERPRETER: + content, content_blocks, end = ( + tag_content_handler( + "code_interpreter", + code_interpreter_tags, + content, + content_blocks, + ) + ) - if ENABLE_REALTIME_CHAT_SAVE: - # Save message in the database - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { + if end: + break + + if DETECT_SOLUTION: + content, content_blocks, _ = ( + tag_content_handler( + "solution", + solution_tags, + content, + content_blocks, + ) + ) + + if ENABLE_REALTIME_CHAT_SAVE: + # Save message in the database + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "content": serialize_content_blocks( + content_blocks + ), + }, + ) + else: + data = { "content": serialize_content_blocks( content_blocks ), - }, - ) - else: - data = { - "content": serialize_content_blocks( - content_blocks - ), - } + } - await event_emitter( - { - "type": "chat:completion", - "data": data, - } - ) + await event_emitter( + { + "type": "chat:completion", + "data": data, + } + ) except Exception as e: done = "data: [DONE]" in line if done: @@ -1855,7 +1912,8 @@ async def process_chat_response( } ) - print(content_blocks, serialize_content_blocks(content_blocks)) + log.info(f"content_blocks={content_blocks}") + 
log.info(f"serialize_content_blocks={serialize_content_blocks(content_blocks)}") try: res = await generate_chat_completion( @@ -1926,7 +1984,7 @@ async def process_chat_response( await background_tasks_handler() except asyncio.CancelledError: - print("Task was cancelled!") + log.warning("Task was cancelled!") await event_emitter({"type": "task-cancelled"}) if not ENABLE_REALTIME_CHAT_SAVE: @@ -1947,17 +2005,34 @@ async def process_chat_response( return {"status": True, "task_id": task_id} else: - # Fallback to the original response async def stream_wrapper(original_generator, events): def wrap_item(item): return f"data: {item}\n\n" for event in events: - yield wrap_item(json.dumps(event)) + event, _ = await process_filter_functions( + request=request, + filter_ids=filter_ids, + filter_type="stream", + form_data=event, + extra_params=extra_params, + ) + + if event: + yield wrap_item(json.dumps(event)) async for data in original_generator: - yield data + data, _ = await process_filter_functions( + request=request, + filter_ids=filter_ids, + filter_type="stream", + form_data=data, + extra_params=extra_params, + ) + + if data: + yield data return StreamingResponse( stream_wrapper(response.body_iterator, events), diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index f79b62684..ce6cbc09c 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -2,13 +2,17 @@ import hashlib import re import time import uuid +import logging from datetime import timedelta from pathlib import Path from typing import Callable, Optional import collections.abc +from open_webui.env import SRC_LOG_LEVELS +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) def deep_update(d, u): for k, v in u.items(): @@ -412,7 +416,7 @@ def parse_ollama_modelfile(model_text): elif param_type is bool: value = value.lower() == "true" except Exception as e: - print(e) + log.exception(f"Failed to parse parameter {param}: {e}") continue data["params"][param] = value diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index a2c0eadca..149e41a41 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -22,6 +22,7 @@ from open_webui.config import ( ) from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL +from open_webui.models.users import UserModel logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) @@ -29,17 +30,17 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) -async def get_all_base_models(request: Request): +async def get_all_base_models(request: Request, user: UserModel = None): function_models = [] openai_models = [] ollama_models = [] if request.app.state.config.ENABLE_OPENAI_API: - openai_models = await openai.get_all_models(request) + openai_models = await openai.get_all_models(request, user=user) openai_models = openai_models["data"] if request.app.state.config.ENABLE_OLLAMA_API: - ollama_models = await ollama.get_all_models(request) + ollama_models = await ollama.get_all_models(request, user=user) ollama_models = [ { "id": model["model"], @@ -58,8 +59,8 @@ async def get_all_base_models(request: Request): return models -async def get_all_models(request): - models = await get_all_base_models(request) +async def get_all_models(request, user: UserModel = None): + models = await get_all_base_models(request, user=user) # If there are no models, return an empty list if len(models) == 0: diff --git a/backend/open_webui/utils/oauth.py 
b/backend/open_webui/utils/oauth.py index 13835e784..2af54c19d 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -146,7 +146,7 @@ class OAuthManager: nested_claims = oauth_claim.split(".") for nested_claim in nested_claims: claim_data = claim_data.get(nested_claim, {}) - user_oauth_groups = claim_data if isinstance(claim_data, list) else None + user_oauth_groups = claim_data if isinstance(claim_data, list) else [] user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id) all_available_groups: list[GroupModel] = Groups.get_groups() @@ -315,15 +315,6 @@ class OAuthManager: if not user: user_count = Users.get_num_users() - if ( - request.app.state.USER_COUNT - and user_count >= request.app.state.USER_COUNT - ): - raise HTTPException( - 403, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - # If the user does not exist, check if signups are enabled if auth_manager_config.ENABLE_OAUTH_SIGNUP: # Check if an existing user with the same email already exists diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index 51e8d50cc..869e70895 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -124,7 +124,7 @@ def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]: tool_call_id = message.get("tool_call_id", None) # Check if the content is a string (just a simple message) - if isinstance(content, str): + if isinstance(content, str) and not tool_calls: # If the content is a string, it's pure text new_message["content"] = content @@ -230,6 +230,12 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict: "system" ] # To prevent Ollama warning of invalid option provided + # If there is the "stop" parameter in the openai_payload, remap it to the ollama_payload.options + if "stop" in openai_payload: + ollama_options = ollama_payload.get("options", {}) + ollama_options["stop"] = openai_payload.get("stop") + ollama_payload["options"] = ollama_options + if "metadata" in openai_payload: ollama_payload["metadata"] = openai_payload["metadata"] diff --git a/backend/open_webui/utils/plugin.py b/backend/open_webui/utils/plugin.py index d6e24d6b9..e3fe9237f 100644 --- a/backend/open_webui/utils/plugin.py +++ b/backend/open_webui/utils/plugin.py @@ -45,7 +45,7 @@ def extract_frontmatter(content): frontmatter[key.strip()] = value.strip() except Exception as e: - print(f"An error occurred: {e}") + log.exception(f"Failed to extract frontmatter: {e}") return {} return frontmatter diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py index bc47e1e13..8c3f1a58e 100644 --- a/backend/open_webui/utils/response.py +++ b/backend/open_webui/utils/response.py @@ -104,7 +104,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response) data = json.loads(data) model = data.get("model", "ollama") - message_content = data.get("message", {}).get("content", "") + message_content = data.get("message", {}).get("content", None) tool_calls = data.get("message", {}).get("tool_calls", None) openai_tool_calls = None @@ -118,7 +118,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response) usage = convert_ollama_usage_to_openai(data) data = openai_chat_chunk_message_template( - model, message_content if not done else None, openai_tool_calls, usage + model, message_content, openai_tool_calls, usage ) line = f"data: {json.dumps(data)}\n\n" diff --git a/backend/requirements.txt 
b/backend/requirements.txt index 965741f78..47d441ad5 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -71,6 +71,7 @@ validators==0.34.0 psutil sentencepiece soundfile==0.13.1 +azure-ai-documentintelligence==1.0.0 opencv-python-headless==4.11.0.86 rapidocr-onnxruntime==1.3.24 diff --git a/package-lock.json b/package-lock.json index c65870772..1ce7424f5 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.5.16", + "version": "0.5.17", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.5.16", + "version": "0.5.17", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", diff --git a/package.json b/package.json index 86568869f..a5db14e19 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.5.16", + "version": "0.5.17", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", diff --git a/pyproject.toml b/pyproject.toml index 5cd54da64..5282a9dba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ dependencies = [ "psutil", "sentencepiece", "soundfile==0.13.1", + "azure-ai-documentintelligence==1.0.0", "opencv-python-headless==4.11.0.86", "rapidocr-onnxruntime==1.3.24", diff --git a/src/lib/apis/retrieval/index.ts b/src/lib/apis/retrieval/index.ts index c35c37847..31317fe0b 100644 --- a/src/lib/apis/retrieval/index.ts +++ b/src/lib/apis/retrieval/index.ts @@ -32,9 +32,15 @@ type ChunkConfigForm = { chunk_overlap: number; }; +type DocumentIntelligenceConfigForm = { + key: string; + endpoint: string; +}; + type ContentExtractConfigForm = { engine: string; tika_server_url: string | null; + document_intelligence_config: DocumentIntelligenceConfigForm | null; }; type YoutubeConfigForm = { @@ -46,6 +52,7 @@ type YoutubeConfigForm = { type RAGConfigForm = { pdf_extract_images?: boolean; enable_google_drive_integration?: boolean; + enable_onedrive_integration?: boolean; chunk?: ChunkConfigForm; content_extraction?: ContentExtractConfigForm; web_loader_ssl_verification?: boolean; diff --git a/src/lib/components/admin/Functions/FunctionEditor.svelte b/src/lib/components/admin/Functions/FunctionEditor.svelte index cbdec2425..6da2a83f4 100644 --- a/src/lib/components/admin/Functions/FunctionEditor.svelte +++ b/src/lib/components/admin/Functions/FunctionEditor.svelte @@ -1,8 +1,7 @@ @@ -585,10 +605,12 @@ bind:value={contentExtractionEngine} on:change={(e) => { showTikaServerUrl = e.target.value === 'tika'; + showDocumentIntelligenceConfig = e.target.value === 'document_intelligence'; }} > + @@ -604,6 +626,21 @@ {/if} + + {#if showDocumentIntelligenceConfig} +
+			<div class="my-0.5 flex gap-2 pr-2">
+				<input
+					class="flex-1 w-full text-sm bg-transparent outline-none"
+					placeholder={$i18n.t('Enter Document Intelligence Endpoint')}
+					bind:value={documentIntelligenceEndpoint}
+				/>
+
+				<SensitiveInput
+					placeholder={$i18n.t('Enter Document Intelligence Key')}
+					bind:value={documentIntelligenceKey}
+				/>
+			</div>
+		{/if}
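For reference, the new Document Intelligence fields can be exercised through the same retrieval config route that update_rag_config serves. A minimal sketch, assuming the router is mounted at /api/v1/retrieval and that an admin bearer token is at hand (both assumptions, not shown in this diff):

# Illustrative only: exercises the document_intelligence_config fields added
# in this patch. The /api/v1/retrieval/config/update path and bearer-token
# auth are assumptions based on how the retrieval router is usually mounted.
import requests

BASE_URL = "http://localhost:8080"  # assumed Open WebUI instance
TOKEN = "<admin-api-key>"  # placeholder; requires an admin user

payload = {
    "content_extraction": {
        "engine": "document_intelligence",
        "tika_server_url": None,
        "document_intelligence_config": {
            "endpoint": "https://<resource>.cognitiveservices.azure.com/",
            "key": "<azure-di-key>",
        },
    },
}

r = requests.post(
    f"{BASE_URL}/api/v1/retrieval/config/update",
    json=payload,
    headers={"Authorization": f"Bearer {TOKEN}"},
)
r.raise_for_status()
# update_rag_config echoes the stored config back, including the new block
print(r.json()["content_extraction"])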
@@ -619,6 +656,17 @@ +
+				<div class="mb-2.5 flex flex-col w-full justify-between">
+					<div class="flex w-full justify-between">
+						<div class="self-center text-xs font-medium">
+							{$i18n.t('OneDrive')}
+						</div>
+					</div>
+
+					<div class="my-2 flex w-full items-center justify-between">
+						<div class="self-center text-xs font-medium">
+							{$i18n.t('Enable OneDrive')}
+						</div>
+
+						<Switch bind:state={enableOneDriveIntegration} />
+					</div>
+				</div>
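Client-side, the toggle above is surfaced through the app config payload. A small sketch of how a caller might check it; the /api/config path and the auth header are assumptions for illustration:

# Minimal sketch: discover whether the OneDrive integration is active.
# The endpoint path and token are placeholders, not guaranteed by this diff;
# the feature flags are only returned to authenticated users.
import requests

resp = requests.get(
    "http://localhost:8080/api/config",
    headers={"Authorization": "Bearer <token>"},  # placeholder token
)
resp.raise_for_status()
config = resp.json()

if config.get("features", {}).get("enable_onedrive_integration"):
    # The client id drives the sign-in popup in the OneDrive picker.
    client_id = config.get("onedrive", {}).get("client_id")
    print(f"OneDrive picker enabled, client_id={client_id}")
else:
    print("OneDrive integration disabled")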
diff --git a/src/lib/components/admin/Settings/WebSearch.svelte b/src/lib/components/admin/Settings/WebSearch.svelte index 84e9d0e5a..84729117b 100644 --- a/src/lib/components/admin/Settings/WebSearch.svelte +++ b/src/lib/components/admin/Settings/WebSearch.svelte @@ -130,6 +130,19 @@
+			<div class="mb-2.5 flex w-full justify-between">
+				<div class="self-center text-xs font-medium">
+					{$i18n.t('Trust Proxy Environment')}
+				</div>
+				<div class="flex items-center relative">
+					<Tooltip
+						content={webConfig.search.trust_env
+							? 'Use proxy designated by http_proxy and https_proxy environment variables to fetch page contents.'
+							: 'Use no proxy to fetch page contents.'}
+					>
+						<Switch bind:state={webConfig.search.trust_env} />
+					</Tooltip>
+				</div>
+			</div>
+
 			{#if webConfig.search.engine !== ''}
{#if webConfig.search.engine === 'searxng'} diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index fcd5177d7..2388af70b 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -187,15 +187,20 @@ setToolIds(); } + $: if (atSelectedModel || selectedModels) { + setToolIds(); + } + const setToolIds = async () => { if (!$tools) { tools.set(await getTools(localStorage.token)); } - if (selectedModels.length !== 1) { + if (selectedModels.length !== 1 && !atSelectedModel) { return; } - const model = $models.find((m) => m.id === selectedModels[0]); + + const model = atSelectedModel ?? $models.find((m) => m.id === selectedModels[0]); if (model) { selectedToolIds = (model?.info?.meta?.toolIds ?? []).filter((id) => $tools.find((t) => t.id === id) @@ -836,6 +841,7 @@ content: m.content, info: m.info ? m.info : undefined, timestamp: m.timestamp, + ...(m.usage ? { usage: m.usage } : {}), ...(m.sources ? { sources: m.sources } : {}) })), model_item: $models.find((m) => m.id === modelId), @@ -1273,7 +1279,9 @@ const chatInputElement = document.getElementById('chat-input'); if (chatInputElement) { + await tick(); chatInputElement.style.height = ''; + chatInputElement.style.height = Math.min(chatInputElement.scrollHeight, 320) + 'px'; } const _files = JSON.parse(JSON.stringify(files)); @@ -1488,7 +1496,10 @@ params?.system ?? $settings?.system ?? '', $user.name, $settings?.userLocation - ? await getAndUpdateUserLocation(localStorage.token) + ? await getAndUpdateUserLocation(localStorage.token).catch((err) => { + console.error(err); + return undefined; + }) : undefined )}${ (responseMessage?.userContext ?? null) @@ -1573,7 +1584,12 @@ variables: { ...getPromptVariables( $user.name, - $settings?.userLocation ? await getAndUpdateUserLocation(localStorage.token) : undefined + $settings?.userLocation + ? await getAndUpdateUserLocation(localStorage.token).catch((err) => { + console.error(err); + return undefined; + }) + : undefined ) }, model_item: $models.find((m) => m.id === model.id), @@ -1965,6 +1981,7 @@ bind:autoScroll bind:prompt {selectedModels} + {atSelectedModel} {sendPrompt} {showMessage} {submitMessage} diff --git a/src/lib/components/chat/ChatPlaceholder.svelte b/src/lib/components/chat/ChatPlaceholder.svelte index e30213ebb..0dfaec968 100644 --- a/src/lib/components/chat/ChatPlaceholder.svelte +++ b/src/lib/components/chat/ChatPlaceholder.svelte @@ -16,6 +16,7 @@ export let modelIds = []; export let models = []; + export let atSelectedModel; export let submitPrompt; @@ -126,7 +127,8 @@
{ diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 5cde963ee..df07cb493 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -2,6 +2,7 @@ import { toast } from 'svelte-sonner'; import { v4 as uuidv4 } from 'uuid'; import { createPicker, getAuthToken } from '$lib/utils/google-drive-picker'; + import { pickAndDownloadFile } from '$lib/utils/onedrive-file-picker'; import { onMount, tick, getContext, createEventDispatcher, onDestroy } from 'svelte'; const dispatch = createEventDispatcher(); @@ -827,7 +828,11 @@ } // Submit the prompt when Enter key is pressed - if (prompt !== '' && e.keyCode === 13 && !e.shiftKey) { + if ( + (prompt !== '' || files.length > 0) && + e.keyCode === 13 && + !e.shiftKey + ) { dispatch('submit', prompt); } } @@ -906,7 +911,11 @@ } // Submit the prompt when Enter key is pressed - if (prompt !== '' && e.key === 'Enter' && !e.shiftKey) { + if ( + (prompt !== '' || files.length > 0) && + e.key === 'Enter' && + !e.shiftKey + ) { dispatch('submit', prompt); } } @@ -1108,6 +1117,21 @@ ); } }} + uploadOneDriveHandler={async () => { + try { + const fileData = await pickAndDownloadFile(); + if (fileData) { + const file = new File([fileData.blob], fileData.name, { + type: fileData.blob.type || 'application/octet-stream' + }); + await uploadFileHandler(file); + } else { + console.log('No file was selected from OneDrive'); + } + } catch (error) { + console.error('OneDrive Error:', error); + } + }} onClose={async () => { await tick(); @@ -1285,14 +1309,17 @@ stream = null; - if (!$TTSWorker) { - await TTSWorker.set( - new KokoroWorker({ - dtype: $settings.audio?.tts?.engineConfig?.dtype ?? 'fp32' - }) - ); + if ($settings.audio?.tts?.engine === 'browser-kokoro') { + // If the user has not initialized the TTS worker, initialize it + if (!$TTSWorker) { + await TTSWorker.set( + new KokoroWorker({ + dtype: $settings.audio?.tts?.engineConfig?.dtype ?? 'fp32' + }) + ); - await $TTSWorker.init(); + await $TTSWorker.init(); + } } showCallOverlay.set(true); diff --git a/src/lib/components/chat/MessageInput/Commands/Prompts.svelte b/src/lib/components/chat/MessageInput/Commands/Prompts.svelte index 76809e6d1..bb91e00a8 100644 --- a/src/lib/components/chat/MessageInput/Commands/Prompts.svelte +++ b/src/lib/components/chat/MessageInput/Commands/Prompts.svelte @@ -74,7 +74,13 @@ } if (command.content.includes('{{USER_LOCATION}}')) { - const location = await getUserPosition(); + let location; + try { + location = await getUserPosition(); + } catch (error) { + toast.error($i18n.t('Location access not allowed')); + location = 'LOCATION_UNKNOWN'; + } text = text.replaceAll('{{USER_LOCATION}}', String(location)); } diff --git a/src/lib/components/chat/MessageInput/InputMenu.svelte b/src/lib/components/chat/MessageInput/InputMenu.svelte index 801093d8f..ff97f0076 100644 --- a/src/lib/components/chat/MessageInput/InputMenu.svelte +++ b/src/lib/components/chat/MessageInput/InputMenu.svelte @@ -5,6 +5,7 @@ import { config, user, tools as _tools, mobile } from '$lib/stores'; import { createPicker } from '$lib/utils/google-drive-picker'; + import { getTools } from '$lib/apis/tools'; import Dropdown from '$lib/components/common/Dropdown.svelte'; @@ -24,6 +25,7 @@ export let inputFilesHandler: Function; export let uploadGoogleDriveHandler: Function; + export let uploadOneDriveHandler: Function; export let selectedToolIds: string[] = []; @@ -225,6 +227,97 @@
{$i18n.t('Google Drive')}
					{/if}
+
+					{#if $config?.features?.enable_onedrive_integration}
+						<DropdownMenu.Item
+							class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-xl"
+							on:click={() => {
+								uploadOneDriveHandler();
+							}}
+						>
+							<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32" class="w-5 h-5" fill="currentColor">
+								<!-- … OneDrive logo paths … -->
+							</svg>
+							<div class="line-clamp-1">{$i18n.t('OneDrive')}</div>
+						</DropdownMenu.Item>
+					{/if}
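The uploadOneDriveHandler added to MessageInput.svelte funnels the picked OneDrive blob into the regular upload path. Roughly the same flow from a script, assuming the files API is mounted at /api/v1/files/ (an assumption, not shown here):

# Rough equivalent of what the upload handler does after the OneDrive picker
# returns a blob: post the bytes to the files API so they enter the same
# processing pipeline as any other attachment.
import requests

with open("report.docx", "rb") as f:  # stand-in for the downloaded blob
    r = requests.post(
        "http://localhost:8080/api/v1/files/",
        headers={"Authorization": "Bearer <token>"},  # placeholder
        files={"file": ("report.docx", f, "application/octet-stream")},
    )
r.raise_for_status()
print(r.json().get("id"))  # file id usable as a chat attachment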
diff --git a/src/lib/components/chat/Messages.svelte b/src/lib/components/chat/Messages.svelte index 7c1c57a3f..72ed2a462 100644 --- a/src/lib/components/chat/Messages.svelte +++ b/src/lib/components/chat/Messages.svelte @@ -32,6 +32,7 @@ export let prompt; export let history = {}; export let selectedModels; + export let atSelectedModel; let messages = []; @@ -349,6 +350,7 @@ {#if Object.keys(history?.messages ?? {}).length == 0} { let text = p; diff --git a/src/lib/components/chat/Messages/CodeBlock.svelte b/src/lib/components/chat/Messages/CodeBlock.svelte index a5d08356f..f92c66210 100644 --- a/src/lib/components/chat/Messages/CodeBlock.svelte +++ b/src/lib/components/chat/Messages/CodeBlock.svelte @@ -1,18 +1,9 @@ @@ -802,10 +815,11 @@
 					{#if !edit}
-						{#if message.done || siblings.length > 1}
-							<div class="flex justify-start overflow-x-auto buttons text-gray-600 dark:text-gray-500">
+						<div class="flex justify-start overflow-x-auto buttons text-gray-600 dark:text-gray-500 mt-0.5">
+							{#if message.done || siblings.length > 1}
 								{#if siblings.length > 1}
{/if}
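The new "stream" filter hook wired up in utils/filter.py and utils/middleware.py earlier in this patch can be exercised by a filter Function along these lines. The hook contract (one event dict per streamed chunk; a falsy return drops the chunk) is inferred from this diff, not from documented API:

# Sketch of a filter Function using the new hook. Handlers may be sync or
# async; process_filter_functions dispatches both.
class Filter:
    def stream(self, event: dict) -> dict:
        # Called once per streamed chunk, before it reaches the client.
        for choice in event.get("choices", []):
            delta = choice.get("delta", {})
            content = delta.get("content")
            if content:
                # Example transformation: strip a placeholder token.
                delta["content"] = content.replace("[REDACTED]", "")
        # Returning a falsy value drops the chunk: the middleware only
        # forwards the filtered data when it is truthy.
        return event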
diff --git a/src/lib/components/chat/Placeholder.svelte b/src/lib/components/chat/Placeholder.svelte index 4ae56f10e..74d6cd9fc 100644 --- a/src/lib/components/chat/Placeholder.svelte +++ b/src/lib/components/chat/Placeholder.svelte @@ -213,7 +213,8 @@
+ +
+					<div class="w-full flex justify-between items-center">
+						<div class=" text-sm">
+							{$i18n.t('Generate prompt pair')}
+						</div>
+
+						<div class="flex space-x-1 text-xs">
+							<div
+								class="h-fit py-1 px-2 flex items-center justify-center rounded-sm border border-black/10 capitalize text-gray-600 dark:border-white/10 dark:text-gray-300"
+							>
+								Ctrl/⌘
+							</div>
+
+							<div
+								class="h-fit py-1 px-2 flex items-center justify-center rounded-sm border border-black/10 capitalize text-gray-600 dark:border-white/10 dark:text-gray-300"
+							>
+								Shift
+							</div>
+
+							<div
+								class="h-fit py-1 px-2 flex items-center justify-center rounded-sm border border-black/10 capitalize text-gray-600 dark:border-white/10 dark:text-gray-300"
+							>
+								Enter
+							</div>
+						</div>
+					</div>
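On the backend, the Trust Proxy Environment switch added in WebSearch.svelte presumably lands as a trust_env flag on the HTTP client used by the web loader; the exact call site is outside this diff. For orientation, this is what the flag means to aiohttp:

# Minimal sketch: aiohttp only honors the http_proxy/https_proxy environment
# variables when trust_env=True is set on the session.
import asyncio

import aiohttp


async def fetch_status(url: str, trust_env: bool) -> int:
    async with aiohttp.ClientSession(trust_env=trust_env) as session:
        async with session.get(url) as resp:
            return resp.status


# With trust_env=True and $https_proxy set, the request goes via the proxy.
print(asyncio.run(fetch_status("https://example.com", trust_env=True)))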
@@ -219,7 +243,7 @@
-						{$i18n.t('Attach file')}
+						{$i18n.t('Attach file from knowledge')}
@@ -247,7 +271,7 @@
-						{$i18n.t('Select model')}
+						{$i18n.t('Talk to model')}
@@ -258,6 +282,20 @@
+
+					<div class="w-full flex justify-between items-center">
+						<div class=" text-sm">
+							{$i18n.t('Accept autocomplete generation / Jump to prompt variable')}
+						</div>
+
+						<div class="flex space-x-1 text-xs">
+							<div
+								class="h-fit py-1 px-2 flex items-center justify-center rounded-sm border border-black/10 capitalize text-gray-600 dark:border-white/10 dark:text-gray-300"
+							>
+								TAB
+							</div>
+						</div>
+					</div>
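The stop-parameter remapping added in utils/payload.py earlier is easy to verify in isolation; the import path below matches this repository's backend layout:

# Quick check of the new behavior; run inside the backend environment.
from open_webui.utils.payload import convert_payload_openai_to_ollama

openai_payload = {
    "model": "llama3.2",
    "messages": [{"role": "user", "content": "Count to 3"}],
    "stop": ["\n\n"],
}

ollama_payload = convert_payload_openai_to_ollama(openai_payload)

# The OpenAI-style top-level "stop" now lands under options, where Ollama
# expects sampling parameters to live.
assert ollama_payload["options"]["stop"] == ["\n\n"]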
diff --git a/src/lib/components/common/CodeEditor.svelte b/src/lib/components/common/CodeEditor.svelte index 7d9f3a55a..d545d7236 100644 --- a/src/lib/components/common/CodeEditor.svelte +++ b/src/lib/components/common/CodeEditor.svelte @@ -21,6 +21,10 @@ export let boilerplate = ''; export let value = ''; + + export let onSave = () => {}; + export let onChange = () => {}; + let _value = ''; $: if (value) { @@ -43,6 +47,10 @@ let codeEditor; + export const focus = () => { + codeEditor.focus(); + }; + let isDarkMode = false; let editorTheme = new Compartment(); let editorLanguage = new Compartment(); @@ -75,7 +83,7 @@ }); _value = formattedCode; - dispatch('change', { value: _value }); + onChange(_value); await tick(); toast.success($i18n.t('Code formatted successfully')); @@ -94,7 +102,7 @@ EditorView.updateListener.of((e) => { if (e.docChanged) { _value = e.state.doc.toString(); - dispatch('change', { value: _value }); + onChange(_value); } }), editorTheme.of([]), @@ -170,7 +178,8 @@ const keydownHandler = async (e) => { if ((e.ctrlKey || e.metaKey) && e.key === 's') { e.preventDefault(); - dispatch('save'); + + onSave(); } // Format code when Ctrl + Shift + F is pressed diff --git a/src/lib/components/workspace/Models/ModelEditor.svelte b/src/lib/components/workspace/Models/ModelEditor.svelte index 34b5a4b7b..170c37f22 100644 --- a/src/lib/components/workspace/Models/ModelEditor.svelte +++ b/src/lib/components/workspace/Models/ModelEditor.svelte @@ -180,7 +180,6 @@ } if (model) { - console.log(model); name = model.name; await tick(); diff --git a/src/lib/components/workspace/Tools/ToolkitEditor.svelte b/src/lib/components/workspace/Tools/ToolkitEditor.svelte index 60d231763..da2f555b4 100644 --- a/src/lib/components/workspace/Tools/ToolkitEditor.svelte +++ b/src/lib/components/workspace/Tools/ToolkitEditor.svelte @@ -1,5 +1,5 @@
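A closing note on the re.escape changes in utils/middleware.py: tag pairs such as |begin_of_thought| contain regex metacharacters, so interpolating them unescaped would quietly turn the pattern into an alternation that matches the empty string. A standalone illustration:

import re

token = "|begin_of_thought|"

# Unescaped, "|" is alternation and one branch is empty, so the pattern
# matches (an empty string) in any input at position 0.
print(bool(re.search(token, "no tags here")))             # True (spurious)

# Escaped, the token only matches its literal spelling.
print(bool(re.search(re.escape(token), "no tags here")))  # False
print(bool(re.search(re.escape(token), "x |begin_of_thought| y")))  # True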