diff --git a/CHANGELOG.md b/CHANGELOG.md
index d19e82c39..795b35802 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,39 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.3.0] - 2024-06-09
+
+### Added
+
+- **📚 Knowledge Support for Models**: Attach documents directly to models from the models workspace, enhancing the information available to each model.
+- **🎙️ Hands-Free Voice Call Feature**: Initiate voice calls without needing to use your hands, making interactions more seamless.
+- **📹 Video Call Feature**: Enable video calls with supported vision models like Llava and GPT-4o, adding a visual dimension to your communications.
+- **🎛️ Enhanced UI for Voice Recording**: Improved user interface for the voice recording feature, making it more intuitive and user-friendly.
+- **🌐 External STT Support**: Added support for external Speech-To-Text services, providing more flexibility in choosing your STT provider.
+- **⚙️ Unified Settings**: Consolidated settings, including document settings, under a new admin settings section for easier management.
+- **🌑 Dark Mode Splash Screen**: A new splash screen for dark mode, ensuring a consistent and visually appealing experience for dark mode users.
+- **📥 Upload Pipeline**: Directly upload pipelines from the admin settings > pipelines section, streamlining the pipeline management process.
+- **🌍 Improved Language Support**: Enhanced support for Chinese and Ukrainian languages, better catering to a global user base.
+
+### Fixed
+
+- **🛠️ Playground Issue**: Fixed the playground not functioning properly, ensuring a smoother user experience.
+- **🔥 Temperature Parameter Issue**: Corrected the issue where the temperature value '0' was not being passed correctly.
+- **📝 Prompt Input Clearing**: Resolved the prompt input textarea not being cleared right away, ensuring a clean slate for new inputs.
+- **✨ Various UI Styling Issues**: Fixed numerous user interface styling problems for a more cohesive look.
+- **👥 Active Users Display**: Fixed active users showing active sessions instead of actual users, now reflecting accurate user activity.
+- **🌐 Community Platform Compatibility**: The Community Platform is back online and fully compatible with Open WebUI.
+
+### Changed
+
+- **📝 RAG Implementation**: Updated the RAG (Retrieval-Augmented Generation) implementation to use a system prompt for context, instead of overriding the user's prompt.
+- **🔄 Settings Relocation**: Moved Models, Connections, Audio, and Images settings to the admin settings for better organization.
+- **✍️ Improved Title Generation**: Enhanced the default prompt for title generation, yielding better results.
+- **🔧 Backend Task Management**: Tasks like title generation and search query generation are now managed on the backend side and controlled only by the admin.
+- **🔍 Editable Search Query Prompt**: You can now edit the search query generation prompt, offering more control over how queries are generated.
+- **📏 Prompt Length Threshold**: Set the prompt length threshold for search query generation from the admin settings, giving more customization options.
+- **📣 Settings Consolidation**: Merged the Banners admin setting with the Interface admin setting for a more streamlined settings area.
+ ## [0.2.5] - 2024-06-05 ### Added diff --git a/README.md b/README.md index a8d79bd5c..7a6df2592 100644 --- a/README.md +++ b/README.md @@ -146,10 +146,19 @@ docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/wa In the last part of the command, replace `open-webui` with your container name if it is different. -### Moving from Ollama WebUI to Open WebUI - Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/migration/). +### Using the Dev Branch 🌙 + +> [!WARNING] +> The `:dev` branch contains the latest unstable features and changes. Use it at your own risk as it may have bugs or incomplete features. + +If you want to try out the latest bleeding-edge features and are okay with occasional instability, you can use the `:dev` tag like this: + +```bash +docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:dev +``` + ## What's Next? 🌟 Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/). diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 0f65a551e..663e20c97 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -17,13 +17,12 @@ from fastapi.middleware.cors import CORSMiddleware from faster_whisper import WhisperModel from pydantic import BaseModel - +import uuid import requests import hashlib from pathlib import Path import json - from constants import ERROR_MESSAGES from utils.utils import ( decode_token, @@ -41,10 +40,15 @@ from config import ( WHISPER_MODEL_DIR, WHISPER_MODEL_AUTO_UPDATE, DEVICE_TYPE, - AUDIO_OPENAI_API_BASE_URL, - AUDIO_OPENAI_API_KEY, - AUDIO_OPENAI_API_MODEL, - AUDIO_OPENAI_API_VOICE, + AUDIO_STT_OPENAI_API_BASE_URL, + AUDIO_STT_OPENAI_API_KEY, + AUDIO_TTS_OPENAI_API_BASE_URL, + AUDIO_TTS_OPENAI_API_KEY, + AUDIO_STT_ENGINE, + AUDIO_STT_MODEL, + AUDIO_TTS_ENGINE, + AUDIO_TTS_MODEL, + AUDIO_TTS_VOICE, AppConfig, ) @@ -61,10 +65,17 @@ app.add_middleware( ) app.state.config = AppConfig() -app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL -app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY -app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL -app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE + +app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL +app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY +app.state.config.STT_ENGINE = AUDIO_STT_ENGINE +app.state.config.STT_MODEL = AUDIO_STT_MODEL + +app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL +app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY +app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE +app.state.config.TTS_MODEL = AUDIO_TTS_MODEL +app.state.config.TTS_VOICE = AUDIO_TTS_VOICE # setting device type for whisper model whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" @@ -74,41 +85,101 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) -class OpenAIConfigUpdateForm(BaseModel): - url: str - key: str - model: str - speaker: str +class TTSConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + ENGINE: str + MODEL: str + VOICE: str + + +class STTConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + ENGINE: str + MODEL: str + + +class AudioConfigUpdateForm(BaseModel): + tts: TTSConfigForm + stt: STTConfigForm + + +from pydub import AudioSegment +from 
pydub.utils import mediainfo + + +def is_mp4_audio(file_path): + """Check if the given file is an MP4 audio file.""" + if not os.path.isfile(file_path): + print(f"File not found: {file_path}") + return False + + info = mediainfo(file_path) + if ( + info.get("codec_name") == "aac" + and info.get("codec_type") == "audio" + and info.get("codec_tag_string") == "mp4a" + ): + return True + return False + + +def convert_mp4_to_wav(file_path, output_path): + """Convert MP4 audio file to WAV format.""" + audio = AudioSegment.from_file(file_path, format="mp4") + audio.export(output_path, format="wav") + print(f"Converted {file_path} to {output_path}") @app.get("/config") -async def get_openai_config(user=Depends(get_admin_user)): +async def get_audio_config(user=Depends(get_admin_user)): return { - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, - "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, - "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, + "tts": { + "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, + "ENGINE": app.state.config.TTS_ENGINE, + "MODEL": app.state.config.TTS_MODEL, + "VOICE": app.state.config.TTS_VOICE, + }, + "stt": { + "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, + "ENGINE": app.state.config.STT_ENGINE, + "MODEL": app.state.config.STT_MODEL, + }, } @app.post("/config/update") -async def update_openai_config( - form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user) +async def update_audio_config( + form_data: AudioConfigUpdateForm, user=Depends(get_admin_user) ): - if form_data.key == "": - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL + app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY + app.state.config.TTS_ENGINE = form_data.tts.ENGINE + app.state.config.TTS_MODEL = form_data.tts.MODEL + app.state.config.TTS_VOICE = form_data.tts.VOICE - app.state.config.OPENAI_API_BASE_URL = form_data.url - app.state.config.OPENAI_API_KEY = form_data.key - app.state.config.OPENAI_API_MODEL = form_data.model - app.state.config.OPENAI_API_VOICE = form_data.speaker + app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL + app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY + app.state.config.STT_ENGINE = form_data.stt.ENGINE + app.state.config.STT_MODEL = form_data.stt.MODEL return { - "status": True, - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, - "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, - "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, + "tts": { + "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, + "ENGINE": app.state.config.TTS_ENGINE, + "MODEL": app.state.config.TTS_MODEL, + "VOICE": app.state.config.TTS_VOICE, + }, + "stt": { + "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, + "ENGINE": app.state.config.STT_ENGINE, + "MODEL": app.state.config.STT_MODEL, + }, } @@ -125,13 +196,21 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) headers = {} - headers["Authorization"] = f"Bearer 
{app.state.config.OPENAI_API_KEY}" + headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}" headers["Content-Type"] = "application/json" + try: + body = body.decode("utf-8") + body = json.loads(body) + body["model"] = app.state.config.TTS_MODEL + body = json.dumps(body).encode("utf-8") + except Exception as e: + pass + r = None try: r = requests.post( - url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech", + url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", data=body, headers=headers, stream=True, @@ -181,41 +260,110 @@ def transcribe( ) try: - filename = file.filename - file_path = f"{UPLOAD_DIR}/{filename}" + ext = file.filename.split(".")[-1] + + id = uuid.uuid4() + filename = f"{id}.{ext}" + + file_dir = f"{CACHE_DIR}/audio/transcriptions" + os.makedirs(file_dir, exist_ok=True) + file_path = f"{file_dir}/{filename}" + + print(filename) + contents = file.file.read() with open(file_path, "wb") as f: f.write(contents) f.close() - whisper_kwargs = { - "model_size_or_path": WHISPER_MODEL, - "device": whisper_device_type, - "compute_type": "int8", - "download_root": WHISPER_MODEL_DIR, - "local_files_only": not WHISPER_MODEL_AUTO_UPDATE, - } + if app.state.config.STT_ENGINE == "": + whisper_kwargs = { + "model_size_or_path": WHISPER_MODEL, + "device": whisper_device_type, + "compute_type": "int8", + "download_root": WHISPER_MODEL_DIR, + "local_files_only": not WHISPER_MODEL_AUTO_UPDATE, + } - log.debug(f"whisper_kwargs: {whisper_kwargs}") + log.debug(f"whisper_kwargs: {whisper_kwargs}") - try: - model = WhisperModel(**whisper_kwargs) - except: - log.warning( - "WhisperModel initialization failed, attempting download with local_files_only=False" + try: + model = WhisperModel(**whisper_kwargs) + except: + log.warning( + "WhisperModel initialization failed, attempting download with local_files_only=False" + ) + whisper_kwargs["local_files_only"] = False + model = WhisperModel(**whisper_kwargs) + + segments, info = model.transcribe(file_path, beam_size=5) + log.info( + "Detected language '%s' with probability %f" + % (info.language, info.language_probability) ) - whisper_kwargs["local_files_only"] = False - model = WhisperModel(**whisper_kwargs) - segments, info = model.transcribe(file_path, beam_size=5) - log.info( - "Detected language '%s' with probability %f" - % (info.language, info.language_probability) - ) + transcript = "".join([segment.text for segment in list(segments)]) - transcript = "".join([segment.text for segment in list(segments)]) + data = {"text": transcript.strip()} - return {"text": transcript.strip()} + # save the transcript to a json file + transcript_file = f"{file_dir}/{id}.json" + with open(transcript_file, "w") as f: + json.dump(data, f) + + print(data) + + return data + + elif app.state.config.STT_ENGINE == "openai": + if is_mp4_audio(file_path): + print("is_mp4_audio") + os.rename(file_path, file_path.replace(".wav", ".mp4")) + # Convert MP4 audio file to WAV format + convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path) + + headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"} + + files = {"file": (filename, open(file_path, "rb"))} + data = {"model": "whisper-1"} + + print(files, data) + + r = None + try: + r = requests.post( + url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", + headers=headers, + files=files, + data=data, + ) + + r.raise_for_status() + + data = r.json() + + # save the transcript to a json file + transcript_file = f"{file_dir}/{id}.json" + with 
open(transcript_file, "w") as f: + json.dump(data, f) + + print(data) + return data + except Exception as e: + log.exception(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"External: {res['error']['message']}" + except: + error_detail = f"External: {e}" + + raise HTTPException( + status_code=r.status_code if r != None else 500, + detail=error_detail, + ) except Exception as e: log.exception(e) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 82cd8d383..144755418 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -41,8 +41,6 @@ from utils.utils import ( get_admin_user, ) -from utils.models import get_model_id_from_custom_model_id - from config import ( SRC_LOG_LEVELS, @@ -728,7 +726,6 @@ async def generate_chat_completion( model_info = Models.get_model_by_id(model_id) if model_info: - print(model_info) if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -764,7 +761,7 @@ async def generate_chat_completion( "frequency_penalty", None ) - if model_info.params.get("temperature", None): + if model_info.params.get("temperature", None) is not None: payload["options"]["temperature"] = model_info.params.get( "temperature", None ) @@ -849,9 +846,14 @@ async def generate_chat_completion( # TODO: we should update this part once Ollama supports other types +class OpenAIChatMessageContent(BaseModel): + type: str + model_config = ConfigDict(extra="allow") + + class OpenAIChatMessage(BaseModel): role: str - content: str + content: Union[str, OpenAIChatMessageContent] model_config = ConfigDict(extra="allow") @@ -879,7 +881,6 @@ async def generate_openai_chat_completion( model_info = Models.get_model_by_id(model_id) if model_info: - print(model_info) if model_info.base_model_id: payload["model"] = model_info.base_model_id diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 472699f1d..93f913dea 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -345,113 +345,97 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use ) -@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(path: str, request: Request, user=Depends(get_verified_user)): +@app.post("/chat/completions") +@app.post("/chat/completions/{url_idx}") +async def generate_chat_completion( + form_data: dict, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), +): idx = 0 + payload = {**form_data} - body = await request.body() - # TODO: Remove below after gpt-4-vision fix from Open AI - # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision) + model_id = form_data.get("model") + model_info = Models.get_model_by_id(model_id) - payload = None + if model_info: + if model_info.base_model_id: + payload["model"] = model_info.base_model_id - try: - if "chat/completions" in path: - body = body.decode("utf-8") - body = json.loads(body) + model_info.params = model_info.params.model_dump() - payload = {**body} + if model_info.params: + if model_info.params.get("temperature", None) is not None: + payload["temperature"] = float(model_info.params.get("temperature")) - model_id = body.get("model") - model_info = Models.get_model_by_id(model_id) + if model_info.params.get("top_p", None): + payload["top_p"] = int(model_info.params.get("top_p", None)) - if model_info: - print(model_info) - if model_info.base_model_id: - 
payload["model"] = model_info.base_model_id + if model_info.params.get("max_tokens", None): + payload["max_tokens"] = int(model_info.params.get("max_tokens", None)) - model_info.params = model_info.params.model_dump() + if model_info.params.get("frequency_penalty", None): + payload["frequency_penalty"] = int( + model_info.params.get("frequency_penalty", None) + ) - if model_info.params: - if model_info.params.get("temperature", None): - payload["temperature"] = int( - model_info.params.get("temperature") + if model_info.params.get("seed", None): + payload["seed"] = model_info.params.get("seed", None) + + if model_info.params.get("stop", None): + payload["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + else None + ) + + if model_info.params.get("system", None): + # Check if the payload already has a system message + # If not, add a system message to the payload + if payload.get("messages"): + for message in payload["messages"]: + if message.get("role") == "system": + message["content"] = ( + model_info.params.get("system", None) + message["content"] ) + break + else: + payload["messages"].insert( + 0, + { + "role": "system", + "content": model_info.params.get("system", None), + }, + ) - if model_info.params.get("top_p", None): - payload["top_p"] = int(model_info.params.get("top_p", None)) + else: + pass - if model_info.params.get("max_tokens", None): - payload["max_tokens"] = int( - model_info.params.get("max_tokens", None) - ) + model = app.state.MODELS[payload.get("model")] + idx = model["urlIdx"] - if model_info.params.get("frequency_penalty", None): - payload["frequency_penalty"] = int( - model_info.params.get("frequency_penalty", None) - ) + if "pipeline" in model and model.get("pipeline"): + payload["user"] = {"name": user.name, "id": user.id} - if model_info.params.get("seed", None): - payload["seed"] = model_info.params.get("seed", None) + # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 + # This is a workaround until OpenAI fixes the issue with this model + if payload.get("model") == "gpt-4-vision-preview": + if "max_tokens" not in payload: + payload["max_tokens"] = 4000 + log.debug("Modified payload:", payload) - if model_info.params.get("stop", None): - payload["stop"] = ( - [ - bytes(stop, "utf-8").decode("unicode_escape") - for stop in model_info.params["stop"] - ] - if model_info.params.get("stop", None) - else None - ) - - if model_info.params.get("system", None): - # Check if the payload already has a system message - # If not, add a system message to the payload - if payload.get("messages"): - for message in payload["messages"]: - if message.get("role") == "system": - message["content"] = ( - model_info.params.get("system", None) - + message["content"] - ) - break - else: - payload["messages"].insert( - 0, - { - "role": "system", - "content": model_info.params.get("system", None), - }, - ) - else: - pass - - model = app.state.MODELS[payload.get("model")] - - idx = model["urlIdx"] - - if "pipeline" in model and model.get("pipeline"): - payload["user"] = {"name": user.name, "id": user.id} - - # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 - # This is a workaround until OpenAI fixes the issue with this model - if payload.get("model") == "gpt-4-vision-preview": - if "max_tokens" not in payload: - payload["max_tokens"] = 4000 - log.debug("Modified payload:", payload) - - # Convert the modified body back to JSON - payload = 
json.dumps(payload) - - except json.JSONDecodeError as e: - log.error("Error loading request body into a dictionary:", e) + # Convert the modified body back to JSON + payload = json.dumps(payload) print(payload) url = app.state.config.OPENAI_API_BASE_URLS[idx] key = app.state.config.OPENAI_API_KEYS[idx] - target_url = f"{url}/{path}" + print(payload) headers = {} headers["Authorization"] = f"Bearer {key}" @@ -464,9 +448,72 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): try: session = aiohttp.ClientSession(trust_env=True) r = await session.request( - method=request.method, - url=target_url, - data=payload if payload else body, + method="POST", + url=f"{url}/chat/completions", + data=payload, + headers=headers, + ) + + r.raise_for_status() + + # Check if response is SSE + if "text/event-stream" in r.headers.get("Content-Type", ""): + streaming = True + return StreamingResponse( + r.content, + status_code=r.status, + headers=dict(r.headers), + background=BackgroundTask( + cleanup_response, response=r, session=session + ), + ) + else: + response_data = await r.json() + return response_data + except Exception as e: + log.exception(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = await r.json() + print(res) + if "error" in res: + error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" + except: + error_detail = f"External: {e}" + raise HTTPException(status_code=r.status if r else 500, detail=error_detail) + finally: + if not streaming and session: + if r: + r.close() + await session.close() + + +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +async def proxy(path: str, request: Request, user=Depends(get_verified_user)): + idx = 0 + + body = await request.body() + + url = app.state.config.OPENAI_API_BASE_URLS[idx] + key = app.state.config.OPENAI_API_KEYS[idx] + + target_url = f"{url}/{path}" + + headers = {} + headers["Authorization"] = f"Bearer {key}" + headers["Content-Type"] = "application/json" + + r = None + session = None + streaming = False + + try: + session = aiohttp.ClientSession(trust_env=True) + r = await session.request( + method=request.method, + url=target_url, + data=body, headers=headers, ) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index d405ef0b4..8816321b3 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -9,6 +9,7 @@ from fastapi import ( ) from fastapi.middleware.cors import CORSMiddleware import os, shutil, logging, re +from datetime import datetime from pathlib import Path from typing import List, Union, Sequence @@ -30,6 +31,7 @@ from langchain_community.document_loaders import ( UnstructuredExcelLoader, UnstructuredPowerPointLoader, YoutubeLoader, + OutlookMessageLoader, ) from langchain.text_splitter import RecursiveCharacterTextSplitter @@ -879,6 +881,13 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b texts = [doc.page_content for doc in docs] metadatas = [doc.metadata for doc in docs] + # ChromaDB does not like datetime formats + # for meta-data so convert them to string. 
+ for metadata in metadatas: + for key, value in metadata.items(): + if isinstance(value, datetime): + metadata[key] = str(value) + try: if overwrite: for collection in CHROMA_CLIENT.list_collections(): @@ -965,6 +974,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str): "swift", "vue", "svelte", + "msg", ] if file_ext == "pdf": @@ -999,6 +1009,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str): "application/vnd.openxmlformats-officedocument.presentationml.presentation", ] or file_ext in ["ppt", "pptx"]: loader = UnstructuredPowerPointLoader(file_path) + elif file_ext == "msg": + loader = OutlookMessageLoader(file_path) elif file_ext in known_source_ext or ( file_content_type and file_content_type.find("text/") >= 0 ): diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index ac52dc3d8..7d92dd10f 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -20,7 +20,7 @@ from langchain.retrievers import ( from typing import Optional - +from utils.misc import get_last_user_message, add_or_update_system_message from config import SRC_LOG_LEVELS, CHROMA_CLIENT log = logging.getLogger(__name__) @@ -247,31 +247,7 @@ def rag_messages( hybrid_search, ): log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}") - - last_user_message_idx = None - for i in range(len(messages) - 1, -1, -1): - if messages[i]["role"] == "user": - last_user_message_idx = i - break - - user_message = messages[last_user_message_idx] - - if isinstance(user_message["content"], list): - # Handle list content input - content_type = "list" - query = "" - for content_item in user_message["content"]: - if content_item["type"] == "text": - query = content_item["text"] - break - elif isinstance(user_message["content"], str): - # Handle text content input - content_type = "text" - query = user_message["content"] - else: - # Fallback in case the input does not match expected types - content_type = None - query = "" + query = get_last_user_message(messages) extracted_collections = [] relevant_contexts = [] @@ -349,24 +325,7 @@ def rag_messages( ) log.debug(f"ra_content: {ra_content}") - - if content_type == "list": - new_content = [] - for content_item in user_message["content"]: - if content_item["type"] == "text": - # Update the text item's content with ra_content - new_content.append({"type": "text", "text": ra_content}) - else: - # Keep other types of content as they are - new_content.append(content_item) - new_user_message = {**user_message, "content": new_content} - else: - new_user_message = { - **user_message, - "content": ra_content, - } - - messages[last_user_message_idx] = new_user_message + messages = add_or_update_system_message(ra_content, messages) return messages, citations diff --git a/backend/apps/socket/main.py b/backend/apps/socket/main.py index 0bc45287a..e70812867 100644 --- a/backend/apps/socket/main.py +++ b/backend/apps/socket/main.py @@ -10,7 +10,7 @@ app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io") # Dictionary to maintain the user pool - +SESSION_POOL = {} USER_POOL = {} USAGE_POOL = {} # Timeout duration in seconds @@ -29,7 +29,12 @@ async def connect(sid, environ, auth): user = Users.get_user_by_id(data["id"]) if user: - USER_POOL[sid] = user.id + SESSION_POOL[sid] = user.id + if user.id in USER_POOL: + USER_POOL[user.id].append(sid) + else: + USER_POOL[user.id] = [sid] + print(f"user {user.name}({user.id}) connected with session ID {sid}") print(len(set(USER_POOL))) @@ -50,7 +55,13 @@ async def 
user_join(sid, data): user = Users.get_user_by_id(data["id"]) if user: - USER_POOL[sid] = user.id + + SESSION_POOL[sid] = user.id + if user.id in USER_POOL: + USER_POOL[user.id].append(sid) + else: + USER_POOL[user.id] = [sid] + print(f"user {user.name}({user.id}) connected with session ID {sid}") print(len(set(USER_POOL))) @@ -123,9 +134,17 @@ async def remove_after_timeout(sid, model_id): @sio.event async def disconnect(sid): - if sid in USER_POOL: - disconnected_user = USER_POOL.pop(sid) - print(f"user {disconnected_user} disconnected with session ID {sid}") + if sid in SESSION_POOL: + user_id = SESSION_POOL[sid] + del SESSION_POOL[sid] + + USER_POOL[user_id].remove(sid) + + if len(USER_POOL[user_id]) == 0: + del USER_POOL[user_id] + + print(f"user {user_id} disconnected with session ID {sid}") + print(USER_POOL) await sio.emit("user-count", {"count": len(USER_POOL)}) else: diff --git a/backend/config.py b/backend/config.py index dd3bc9e4b..27c4c1277 100644 --- a/backend/config.py +++ b/backend/config.py @@ -306,7 +306,10 @@ STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve() frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png" if frontend_favicon.exists(): - shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png") + try: + shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png") + except PermissionError: + logging.error(f"No write permission to {STATIC_DIR / 'favicon.png'}") else: logging.warning(f"Frontend favicon not found at {frontend_favicon}") @@ -615,6 +618,66 @@ ADMIN_EMAIL = PersistentConfig( ) +#################################### +# TASKS +#################################### + + +TASK_MODEL = PersistentConfig( + "TASK_MODEL", + "task.model.default", + os.environ.get("TASK_MODEL", ""), +) + +TASK_MODEL_EXTERNAL = PersistentConfig( + "TASK_MODEL_EXTERNAL", + "task.model.external", + os.environ.get("TASK_MODEL_EXTERNAL", ""), +) + +TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "TITLE_GENERATION_PROMPT_TEMPLATE", + "task.title.prompt_template", + os.environ.get( + "TITLE_GENERATION_PROMPT_TEMPLATE", + """Here is the query: +{{prompt:middletruncate:8000}} + +Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. + +Examples of titles: +📉 Stock Market Trends +🍪 Perfect Chocolate Chip Recipe +Evolution of Music Streaming +Remote Work Productivity Tips +Artificial Intelligence in Healthcare +🎮 Video Game Development Insights""", + ), +) + + +SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", + "task.search.prompt_template", + os.environ.get( + "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", + """You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}. 
+ +Question: +{{prompt:end:4000}}""", + ), +) + + +SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig( + "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD", + "task.search.prompt_length_threshold", + os.environ.get( + "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD", + 100, + ), +) + #################################### # WEBUI_SECRET_KEY #################################### @@ -933,25 +996,59 @@ IMAGE_GENERATION_MODEL = PersistentConfig( # Audio #################################### -AUDIO_OPENAI_API_BASE_URL = PersistentConfig( - "AUDIO_OPENAI_API_BASE_URL", - "audio.openai.api_base_url", - os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig( + "AUDIO_STT_OPENAI_API_BASE_URL", + "audio.stt.openai.api_base_url", + os.getenv("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), ) -AUDIO_OPENAI_API_KEY = PersistentConfig( - "AUDIO_OPENAI_API_KEY", - "audio.openai.api_key", - os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY), + +AUDIO_STT_OPENAI_API_KEY = PersistentConfig( + "AUDIO_STT_OPENAI_API_KEY", + "audio.stt.openai.api_key", + os.getenv("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY), ) -AUDIO_OPENAI_API_MODEL = PersistentConfig( - "AUDIO_OPENAI_API_MODEL", - "audio.openai.api_model", - os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"), + +AUDIO_STT_ENGINE = PersistentConfig( + "AUDIO_STT_ENGINE", + "audio.stt.engine", + os.getenv("AUDIO_STT_ENGINE", ""), ) -AUDIO_OPENAI_API_VOICE = PersistentConfig( - "AUDIO_OPENAI_API_VOICE", - "audio.openai.api_voice", - os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"), + +AUDIO_STT_MODEL = PersistentConfig( + "AUDIO_STT_MODEL", + "audio.stt.model", + os.getenv("AUDIO_STT_MODEL", "whisper-1"), +) + +AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig( + "AUDIO_TTS_OPENAI_API_BASE_URL", + "audio.tts.openai.api_base_url", + os.getenv("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +) +AUDIO_TTS_OPENAI_API_KEY = PersistentConfig( + "AUDIO_TTS_OPENAI_API_KEY", + "audio.tts.openai.api_key", + os.getenv("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY), +) + + +AUDIO_TTS_ENGINE = PersistentConfig( + "AUDIO_TTS_ENGINE", + "audio.tts.engine", + os.getenv("AUDIO_TTS_ENGINE", ""), +) + + +AUDIO_TTS_MODEL = PersistentConfig( + "AUDIO_TTS_MODEL", + "audio.tts.model", + os.getenv("AUDIO_TTS_MODEL", "tts-1"), +) + +AUDIO_TTS_VOICE = PersistentConfig( + "AUDIO_TTS_VOICE", + "audio.tts.voice", + os.getenv("AUDIO_TTS_VOICE", "alloy"), ) diff --git a/backend/main.py b/backend/main.py index 4ab13e98f..99b409983 100644 --- a/backend/main.py +++ b/backend/main.py @@ -9,8 +9,11 @@ import logging import aiohttp import requests import mimetypes +import shutil +import os +import asyncio -from fastapi import FastAPI, Request, Depends, status +from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi.staticfiles import StaticFiles from fastapi.responses import JSONResponse from fastapi import HTTPException @@ -22,15 +25,24 @@ from starlette.responses import StreamingResponse, Response from apps.socket.main import app as socket_app -from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models -from apps.openai.main import app as openai_app, get_all_models as get_openai_models +from apps.ollama.main import ( + app as ollama_app, + OpenAIChatCompletionForm, + get_all_models as get_ollama_models, + generate_openai_chat_completion as generate_ollama_chat_completion, +) +from apps.openai.main import ( + app as openai_app, + get_all_models as get_openai_models, + generate_chat_completion as 
generate_openai_chat_completion, +) from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app from apps.webui.main import app as webui_app -import asyncio + from pydantic import BaseModel from typing import List, Optional @@ -41,6 +53,8 @@ from utils.utils import ( get_current_user, get_http_authorization_cred, ) +from utils.task import title_generation_template, search_query_generation_template + from apps.rag.utils import rag_messages from config import ( @@ -62,8 +76,13 @@ from config import ( SRC_LOG_LEVELS, WEBHOOK_URL, ENABLE_ADMIN_EXPORT, - AppConfig, WEBUI_BUILD_HASH, + TASK_MODEL, + TASK_MODEL_EXTERNAL, + TITLE_GENERATION_PROMPT_TEMPLATE, + SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, + SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + AppConfig, ) from constants import ERROR_MESSAGES @@ -117,10 +136,19 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - app.state.config.WEBHOOK_URL = WEBHOOK_URL +app.state.config.TASK_MODEL = TASK_MODEL +app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL +app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE +app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( + SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE +) +app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( + SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD +) + app.state.MODELS = {} origins = ["*"] @@ -228,6 +256,78 @@ class RAGMiddleware(BaseHTTPMiddleware): app.add_middleware(RAGMiddleware) +def filter_pipeline(payload, user): + user = {"id": user.id, "name": user.name, "role": user.role} + model_id = payload["model"] + filters = [ + model + for model in app.state.MODELS.values() + if "pipeline" in model + and "type" in model["pipeline"] + and model["pipeline"]["type"] == "filter" + and ( + model["pipeline"]["pipelines"] == ["*"] + or any( + model_id == target_model_id + for target_model_id in model["pipeline"]["pipelines"] + ) + ) + ] + sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + + model = app.state.MODELS[model_id] + + if "pipeline" in model: + sorted_filters.append(model) + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + if key != "": + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/inlet", + headers=headers, + json={ + "user": user, + "body": payload, + }, + ) + + r.raise_for_status() + payload = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + try: + res = r.json() + if "detail" in res: + return JSONResponse( + status_code=r.status_code, + content=res, + ) + except: + pass + + else: + pass + + if "pipeline" not in app.state.MODELS[model_id]: + if "chat_id" in payload: + del payload["chat_id"] + + if "title" in payload: + del payload["title"] + return payload + + class PipelineMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): if request.method == "POST" and ( @@ -243,85 +343,10 @@ class PipelineMiddleware(BaseHTTPMiddleware): # Parse string to JSON data = json.loads(body_str) if body_str else {} - model_id = data["model"] - filters = [ - model - for model in app.state.MODELS.values() - if "pipeline" in model - and "type" in 
model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) - ] - sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) - - user = None - if len(sorted_filters) > 0: - try: - user = get_current_user( - get_http_authorization_cred( - request.headers.get("Authorization") - ) - ) - user = {"id": user.id, "name": user.name, "role": user.role} - except: - pass - - model = app.state.MODELS[model_id] - - if "pipeline" in model: - sorted_filters.append(model) - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/inlet", - headers=headers, - json={ - "user": user, - "body": data, - }, - ) - - r.raise_for_status() - data = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - try: - res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) - except: - pass - - else: - pass - - if "pipeline" not in app.state.MODELS[model_id]: - if "chat_id" in data: - del data["chat_id"] - - if "title" in data: - del data["title"] + user = get_current_user( + get_http_authorization_cred(request.headers.get("Authorization")) + ) + data = filter_pipeline(data, user) modified_body_bytes = json.dumps(data).encode("utf-8") # Replace the request body with the modified one @@ -482,6 +507,178 @@ async def get_models(user=Depends(get_verified_user)): return {"data": models} +@app.get("/api/task/config") +async def get_task_config(user=Depends(get_verified_user)): + return { + "TASK_MODEL": app.state.config.TASK_MODEL, + "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, + "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, + "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + } + + +class TaskConfigForm(BaseModel): + TASK_MODEL: Optional[str] + TASK_MODEL_EXTERNAL: Optional[str] + TITLE_GENERATION_PROMPT_TEMPLATE: str + SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str + SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int + + +@app.post("/api/task/config/update") +async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)): + app.state.config.TASK_MODEL = form_data.TASK_MODEL + app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL + app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( + form_data.TITLE_GENERATION_PROMPT_TEMPLATE + ) + app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( + form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE + ) + app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( + form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD + ) + + return { + "TASK_MODEL": app.state.config.TASK_MODEL, + "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, + "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, + "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": 
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + } + + +@app.post("/api/task/title/completions") +async def generate_title(form_data: dict, user=Depends(get_verified_user)): + print("generate_title") + + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[model_id]["owned_by"] == "ollama": + if app.state.config.TASK_MODEL: + task_model_id = app.state.config.TASK_MODEL + if task_model_id in app.state.MODELS: + model_id = task_model_id + else: + if app.state.config.TASK_MODEL_EXTERNAL: + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + if task_model_id in app.state.MODELS: + model_id = task_model_id + + print(model_id) + model = app.state.MODELS[model_id] + + template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE + + content = title_generation_template( + template, form_data["prompt"], user.model_dump() + ) + + payload = { + "model": model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "max_tokens": 50, + "chat_id": form_data.get("chat_id", None), + "title": True, + } + + print(payload) + payload = filter_pipeline(payload, user) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + return await generate_openai_chat_completion(payload, user=user) + + +@app.post("/api/task/query/completions") +async def generate_search_query(form_data: dict, user=Depends(get_verified_user)): + print("generate_search_query") + + if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)", + ) + + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[model_id]["owned_by"] == "ollama": + if app.state.config.TASK_MODEL: + task_model_id = app.state.config.TASK_MODEL + if task_model_id in app.state.MODELS: + model_id = task_model_id + else: + if app.state.config.TASK_MODEL_EXTERNAL: + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + if task_model_id in app.state.MODELS: + model_id = task_model_id + + print(model_id) + model = app.state.MODELS[model_id] + + template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE + + content = search_query_generation_template( + template, form_data["prompt"], user.model_dump() + ) + + payload = { + "model": model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "max_tokens": 30, + } + + print(payload) + payload = filter_pipeline(payload, user) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + return await generate_openai_chat_completion(payload, user=user) + + +@app.post("/api/chat/completions") +async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + 
status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = app.state.MODELS[model_id] + print(model) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**form_data), user=user + ) + else: + return await generate_openai_chat_completion(form_data, user=user) + + @app.post("/api/chat/completed") async def chat_completed(form_data: dict, user=Depends(get_verified_user)): data = form_data @@ -574,6 +771,63 @@ async def get_pipelines_list(user=Depends(get_admin_user)): } +@app.post("/api/pipelines/upload") +async def upload_pipeline( + urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) +): + print("upload_pipeline", urlIdx, file.filename) + # Check if the uploaded file is a python file + if not file.filename.endswith(".py"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Only Python (.py) files are allowed.", + ) + + upload_folder = f"{CACHE_DIR}/pipelines" + os.makedirs(upload_folder, exist_ok=True) + file_path = os.path.join(upload_folder, file.filename) + + try: + # Save the uploaded file + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + + with open(file_path, "rb") as f: + files = {"file": f} + r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) + finally: + # Ensure the file is deleted after the upload is completed or on failure + if os.path.exists(file_path): + os.remove(file_path) + + class AddPipelineForm(BaseModel): url: str urlIdx: int @@ -840,6 +1094,15 @@ async def get_app_config(): "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, "enable_admin_export": ENABLE_ADMIN_EXPORT, }, + "audio": { + "tts": { + "engine": audio_app.state.config.TTS_ENGINE, + "voice": audio_app.state.config.TTS_VOICE, + }, + "stt": { + "engine": audio_app.state.config.STT_ENGINE, + }, + }, } @@ -902,7 +1165,7 @@ async def get_app_changelog(): @app.get("/api/version/updates") async def get_app_latest_release_version(): try: - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( "https://api.github.com/repos/open-webui/open-webui/releases/latest" ) as response: diff --git a/backend/requirements.txt b/backend/requirements.txt index 7a3668428..38c0c6915 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -56,4 +56,7 @@ PyJWT[crypto]==2.8.0 black==24.4.2 langfuse==2.33.0 youtube-transcript-api==0.6.2 -pytube==15.0.0 \ No newline at end of file +pytube==15.0.0 + +extract_msg +pydub \ No newline at end of file diff --git a/backend/start.sh b/backend/start.sh index 15fc568d3..16a004e45 100755 --- a/backend/start.sh +++ b/backend/start.sh @@ -20,12 +20,12 @@ if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then WEBUI_SECRET_KEY=$(cat "$KEY_FILE") fi -if [ "$USE_OLLAMA_DOCKER" = "true" ]; then +if [[ 
"${USE_OLLAMA_DOCKER,,}" == "true" ]]; then echo "USE_OLLAMA is set to true, starting ollama serve." ollama serve & fi -if [ "$USE_CUDA_DOCKER" = "true" ]; then +if [[ "${USE_CUDA_DOCKER,,}" == "true" ]]; then echo "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries." export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib" fi diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 9069857b7..c3c65d3f5 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -3,7 +3,48 @@ import hashlib import json import re from datetime import timedelta -from typing import Optional +from typing import Optional, List + + +def get_last_user_message(messages: List[dict]) -> str: + for message in reversed(messages): + if message["role"] == "user": + if isinstance(message["content"], list): + for item in message["content"]: + if item["type"] == "text": + return item["text"] + return message["content"] + return None + + +def get_last_assistant_message(messages: List[dict]) -> str: + for message in reversed(messages): + if message["role"] == "assistant": + if isinstance(message["content"], list): + for item in message["content"]: + if item["type"] == "text": + return item["text"] + return message["content"] + return None + + +def add_or_update_system_message(content: str, messages: List[dict]): + """ + Adds a new system message at the beginning of the messages list + or updates the existing system message at the beginning. + + :param msg: The message to be added or appended. + :param messages: The list of message dictionaries. + :return: The updated list of message dictionaries. + """ + + if messages and messages[0].get("role") == "system": + messages[0]["content"] += f"{content}\n{messages[0]['content']}" + else: + # Insert at the beginning + messages.insert(0, {"role": "system", "content": content}) + + return messages def get_gravatar_url(email): @@ -193,8 +234,14 @@ def parse_ollama_modelfile(model_text): system_desc_match = re.search( r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE ) + system_desc_match_single = re.search( + r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE + ) + if system_desc_match: data["params"]["system"] = system_desc_match.group(1).strip() + elif system_desc_match_single: + data["params"]["system"] = system_desc_match_single.group(1).strip() # Parse messages messages = [] diff --git a/backend/utils/models.py b/backend/utils/models.py deleted file mode 100644 index c4d675d29..000000000 --- a/backend/utils/models.py +++ /dev/null @@ -1,10 +0,0 @@ -from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse - - -def get_model_id_from_custom_model_id(id: str): - model = Models.get_model_by_id(id) - - if model: - return model.id - else: - return id diff --git a/backend/utils/task.py b/backend/utils/task.py new file mode 100644 index 000000000..2239de7df --- /dev/null +++ b/backend/utils/task.py @@ -0,0 +1,112 @@ +import re +import math + +from datetime import datetime +from typing import Optional + + +def prompt_template( + template: str, user_name: str = None, current_location: str = None +) -> str: + # Get the current date + current_date = datetime.now() + + # Format the date to YYYY-MM-DD + formatted_date = current_date.strftime("%Y-%m-%d") + + # Replace {{CURRENT_DATE}} in the template with the formatted date + template = template.replace("{{CURRENT_DATE}}", formatted_date) + + if user_name: + # Replace 
{{USER_NAME}} in the template with the user's name + template = template.replace("{{USER_NAME}}", user_name) + + if current_location: + # Replace {{CURRENT_LOCATION}} in the template with the current location + template = template.replace("{{CURRENT_LOCATION}}", current_location) + + return template + + +def title_generation_template( + template: str, prompt: str, user: Optional[dict] = None +) -> str: + def replacement_function(match): + full_match = match.group(0) + start_length = match.group(1) + end_length = match.group(2) + middle_length = match.group(3) + + if full_match == "{{prompt}}": + return prompt + elif start_length is not None: + return prompt[: int(start_length)] + elif end_length is not None: + return prompt[-int(end_length) :] + elif middle_length is not None: + middle_length = int(middle_length) + if len(prompt) <= middle_length: + return prompt + start = prompt[: math.ceil(middle_length / 2)] + end = prompt[-math.floor(middle_length / 2) :] + return f"{start}...{end}" + return "" + + template = re.sub( + r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}", + replacement_function, + template, + ) + + template = prompt_template( + template, + **( + {"user_name": user.get("name"), "current_location": user.get("location")} + if user + else {} + ), + ) + + return template + + +def search_query_generation_template( + template: str, prompt: str, user: Optional[dict] = None +) -> str: + + def replacement_function(match): + full_match = match.group(0) + start_length = match.group(1) + end_length = match.group(2) + middle_length = match.group(3) + + if full_match == "{{prompt}}": + return prompt + elif start_length is not None: + return prompt[: int(start_length)] + elif end_length is not None: + return prompt[-int(end_length) :] + elif middle_length is not None: + middle_length = int(middle_length) + if len(prompt) <= middle_length: + return prompt + start = prompt[: math.ceil(middle_length / 2)] + end = prompt[-math.floor(middle_length / 2) :] + return f"{start}...{end}" + return "" + + template = re.sub( + r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}", + replacement_function, + template, + ) + + template = prompt_template( + template, + **( + {"user_name": user.get("name"), "current_location": user.get("location")} + if user + else {} + ), + ) + return template diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 92238d307..325964b1a 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -41,7 +41,7 @@ Looking to contribute? Great! Here's how you can help: We welcome pull requests. Before submitting one, please: -1. Discuss your idea or issue in the [issues section](https://github.com/open-webui/open-webui/issues). +1. Open a discussion regarding your ideas [here](https://github.com/open-webui/open-webui/discussions/new/choose). 2. Follow the project's coding standards and include tests for new features. 3. Update documentation as necessary. 4. Write clear, descriptive commit messages. 
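A note on the `backend/apps/openai/main.py` hunk earlier in this diff: the new params mapping casts `top_p` and `frequency_penalty` with `int()`, which truncates fractional values (for example, `top_p=0.9` becomes `0`); only `temperature` was switched to `float()`. Below is a minimal sketch of a type-correct coercion under that observation; `apply_model_params` is an illustrative helper name, not a function from this diff:

```python
# Illustrative sketch, not part of the diff: type-correct coercion for the
# model params copied onto the payload in backend/apps/openai/main.py.
# top_p and frequency_penalty are floats in the OpenAI API, so int() would
# truncate them; the `is not None` checks also keep explicit 0 / 0.0 values.
def apply_model_params(payload: dict, params: dict) -> dict:
    for name in ("temperature", "top_p", "frequency_penalty"):
        if params.get(name) is not None:
            payload[name] = float(params[name])

    if params.get("max_tokens") is not None:
        payload["max_tokens"] = int(params["max_tokens"])  # token counts are integral

    if params.get("seed") is not None:
        payload["seed"] = params["seed"]

    return payload


print(apply_model_params({}, {"temperature": 0, "top_p": 0.9, "max_tokens": "512"}))
# {'temperature': 0.0, 'top_p': 0.9, 'max_tokens': 512}
```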
diff --git a/package-lock.json b/package-lock.json
index 7d3b385e2..cf42022e8 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,12 +1,12 @@
 {
 	"name": "open-webui",
-	"version": "0.2.5",
+	"version": "0.3.0",
 	"lockfileVersion": 3,
 	"requires": true,
 	"packages": {
 		"": {
 			"name": "open-webui",
-			"version": "0.2.5",
+			"version": "0.3.0",
 			"dependencies": {
 				"@pyscript/core": "^0.4.32",
 				"@sveltejs/adapter-node": "^1.3.1",
diff --git a/package.json b/package.json
index 7ea3bf3c7..6827728be 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
 	"name": "open-webui",
-	"version": "0.2.5",
+	"version": "0.3.0",
 	"private": true,
 	"scripts": {
 		"dev": "npm run pyodide:fetch && vite dev --host",
diff --git a/src/app.html b/src/app.html
index 138fb2829..a71d5ff4c 100644
--- a/src/app.html
+++ b/src/app.html
@@ -59,15 +59,7 @@
[hunk body not preserved in this extract: splash-screen markup changes in src/app.html]
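The `{{prompt:middletruncate:N}}` placeholder handled by `backend/utils/task.py` earlier in this diff keeps the first `ceil(N/2)` and the last `floor(N/2)` characters of an over-long prompt around an ellipsis. A standalone sketch of that truncation rule (the `middletruncate` function name is illustrative, not from the diff):

```python
import math

# Standalone sketch of the {{prompt:middletruncate:N}} rule implemented by
# replacement_function in backend/utils/task.py: prompts no longer than N
# pass through unchanged; longer ones keep ceil(N/2) leading and floor(N/2)
# trailing characters.
def middletruncate(prompt: str, n: int) -> str:
    if len(prompt) <= n:
        return prompt
    return prompt[: math.ceil(n / 2)] + "..." + prompt[-math.floor(n / 2):]

print(middletruncate("The quick brown fox jumps over the lazy dog", 12))
# The qu...zy dog
```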
diff --git a/src/lib/apis/audio/index.ts b/src/lib/apis/audio/index.ts
index 7bd8981fe..9716c552a 100644
--- a/src/lib/apis/audio/index.ts
+++ b/src/lib/apis/audio/index.ts
@@ -98,7 +98,7 @@ export const synthesizeOpenAISpeech = async (
 	token: string = '',
 	speaker: string = 'alloy',
 	text: string = '',
-	model: string = 'tts-1'
+	model?: string
 ) => {
 	let error = null;
 
@@ -109,9 +109,9 @@ export const synthesizeOpenAISpeech = async (
 			'Content-Type': 'application/json'
 		},
 		body: JSON.stringify({
-			model: model,
 			input: text,
-			voice: speaker
+			voice: speaker,
+			...(model && { model })
 		})
 	})
 		.then(async (res) => {
diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts
index f6b2de4d0..c40815611 100644
--- a/src/lib/apis/index.ts
+++ b/src/lib/apis/index.ts
@@ -104,6 +104,147 @@ export const chatCompleted = async (token: string, body: ChatCompletedForm) =>
 	return res;
 };
 
+export const getTaskConfig = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/task/config`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateTaskConfig = async (token: string, config: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/task/config/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify(config)
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = err;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const generateTitle = async (
+	token: string = '',
+	model: string,
+	prompt: string,
+	chat_id?: string
+) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/task/title/completions`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			model: model,
+			prompt: prompt,
+			...(chat_id && { chat_id: chat_id })
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
+};
+
+export const generateSearchQuery = async (
+	token: string = '',
+	model: string,
+	messages: object[],
+	prompt: string
+) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			model: model,
+			messages: messages,
+			prompt: prompt
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt;
+};
+
 export const getPipelinesList = async (token: string = '') => {
 	let error = null;
 
@@ -133,6 +274,43 @@ export const getPipelinesList = async (token: string = '') => {
 	return pipelines;
 };
 
+export const uploadPipeline = async (token: string, file: File, urlIdx: string) => {
+	let error = null;
+
+	// Create a new FormData object to handle the file upload
+	const formData = new FormData();
+	formData.append('file', file);
+	formData.append('urlIdx', urlIdx);
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/upload`, {
+		method: 'POST',
+		headers: {
+			...(token && { authorization: `Bearer ${token}` })
+			// 'Content-Type': 'multipart/form-data' is not needed as Fetch API will set it automatically
+		},
+		body: formData
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = err;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const downloadPipeline = async (token: string, url: string, urlIdx: string) => {
 	let error = null;
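For orientation, a minimal sketch of how a caller might use the new task helpers defined above. The model name and chat id below are illustrative placeholders, not values from this diff:

```ts
import { generateTitle, generateSearchQuery } from '$lib/apis';

// Sketch: both helpers fall back internally ('New Chat' for titles, the raw
// prompt for search queries), so callers can use the result directly.
const summarize = async () => {
	const token = localStorage.token;

	// chat_id is optional; it is only included in the body when provided.
	const title = await generateTitle(token, 'llama3', 'Why is the sky blue?', 'chat-123');

	// An empty message history is valid; the backend builds the query prompt.
	const query = await generateSearchQuery(token, 'llama3', [], 'Why is the sky blue?');

	return { title, query };
};
```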
diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte
new file mode 100644
index 000000000..5538a11cf
--- /dev/null
+++ b/src/lib/components/admin/Settings.svelte
@@ -0,0 +1,390 @@
+	{#if selectedTab === 'general'}
+			{
+				toast.success($i18n.t('Settings saved successfully!'));
+			}}
+		/>
+	{:else if selectedTab === 'users'}
+			{
+				toast.success($i18n.t('Settings saved successfully!'));
+			}}
+		/>
+	{:else if selectedTab === 'connections'}
+			{
+				toast.success($i18n.t('Settings saved successfully!'));
+			}}
+		/>
+	{:else if selectedTab === 'models'}
+	{:else if selectedTab === 'documents'}
+			{
+				toast.success($i18n.t('Settings saved successfully!'));
+			}}
+		/>
+	{:else if selectedTab === 'web'}
+			{
+				toast.success($i18n.t('Settings saved successfully!'));
+
+				await tick();
+				await config.set(await getBackendConfig());
+			}}
+		/>
+	{:else if selectedTab === 'interface'}
+			{
+				toast.success($i18n.t('Settings saved successfully!'));
+			}}
+		/>
+	{:else if selectedTab === 'audio'}
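Only the tab-dispatch skeleton of this new Settings component is visible above. A rough TypeScript sketch of the save callbacks it wires into each tab; import paths are assumed and the i18n wrapper is elided:

```ts
import { toast } from 'svelte-sonner';
import { tick } from 'svelte';
import { getBackendConfig } from '$lib/apis';
import { config } from '$lib/stores';

// Most tabs share the same callback: confirm the save with a toast.
const saveHandler = () => {
	toast.success('Settings saved successfully!');
};

// The 'web' tab additionally refreshes the cached backend config after
// saving, matching the `config.set(await getBackendConfig())` fragment above.
const webSaveHandler = async () => {
	toast.success('Settings saved successfully!');
	await tick();
	config.set(await getBackendConfig());
};
```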
diff --git a/src/lib/components/admin/Settings/Audio.svelte b/src/lib/components/admin/Settings/Audio.svelte
new file mode 100644
index 000000000..d38402aa0
--- /dev/null
+++ b/src/lib/components/admin/Settings/Audio.svelte
@@ -0,0 +1,302 @@
+	{
+		await updateConfigHandler();
+		dispatch('save');
+	}}
+>
+	{$i18n.t('STT Settings')}
+	{$i18n.t('Speech-to-Text Engine')}
+	{#if STT_ENGINE === 'openai'}
+		{$i18n.t('STT Model')}
+	{/if}
+
+	{$i18n.t('TTS Settings')}
+	{$i18n.t('Text-to-Speech Engine')}
+	{#if TTS_ENGINE === 'openai'}
+	{/if}
+
+	{#if TTS_ENGINE === ''}
+		{$i18n.t('TTS Voice')}
+	{:else if TTS_ENGINE === 'openai'}
+		{$i18n.t('TTS Voice')}
+		{#each voices as voice}
+		{$i18n.t('TTS Model')}
+		{#each models as model}
+	{/if}
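A sketch of the save call this new Audio panel appears to make. The `updateAudioConfig` helper name and the exact payload keys are assumptions; only the `STT_ENGINE`/`TTS_ENGINE` values and the model/voice selects are grounded in the fragments above:

```ts
import { updateAudioConfig } from '$lib/apis/audio'; // helper name assumed

// Sketch: persist STT and TTS settings together. An empty ENGINE string
// appears to select the built-in default, while 'openai' switches to an
// OpenAI-compatible endpoint, which is when the model/voice selects above
// become relevant.
const updateConfigHandler = async (token: string) =>
	updateAudioConfig(token, {
		stt: {
			ENGINE: 'openai',
			MODEL: 'whisper-1',
			OPENAI_API_BASE_URL: 'https://api.openai.com/v1',
			OPENAI_API_KEY: 'sk-...'
		},
		tts: {
			ENGINE: 'openai',
			MODEL: 'tts-1',
			VOICE: 'alloy',
			OPENAI_API_BASE_URL: 'https://api.openai.com/v1',
			OPENAI_API_KEY: 'sk-...'
		}
	});
```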
diff --git a/src/lib/components/admin/Settings/Banners.svelte b/src/lib/components/admin/Settings/Banners.svelte
deleted file mode 100644
index e69a8ebb1..000000000
--- a/src/lib/components/admin/Settings/Banners.svelte
+++ /dev/null
@@ -1,137 +0,0 @@
-	{
-		updateBanners();
-		saveHandler();
-	}}
->
-	{$i18n.t('Banners')}
-	{#each banners as banner, bannerIdx}
-	{/each}
diff --git a/src/lib/components/chat/Settings/Connections.svelte b/src/lib/components/admin/Settings/Connections.svelte
similarity index 98%
rename from src/lib/components/chat/Settings/Connections.svelte
rename to src/lib/components/admin/Settings/Connections.svelte
index 80fdcf45f..669fe8aae 100644
--- a/src/lib/components/chat/Settings/Connections.svelte
+++ b/src/lib/components/admin/Settings/Connections.svelte
@@ -23,10 +23,14 @@
 	import Switch from '$lib/components/common/Switch.svelte';
 	import Spinner from '$lib/components/common/Spinner.svelte';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
+	import { getModels as _getModels } from '$lib/apis';
 
 	const i18n = getContext('i18n');
 
-	export let getModels: Function;
+	const getModels = async () => {
+		const models = await _getModels(localStorage.token);
+		return models;
+	};
 
 	// External
 	let OLLAMA_BASE_URLS = [''];
@@ -158,7 +162,7 @@
 		dispatch('save');
 	}}
>
-
+
 	{#if ENABLE_OPENAI_API !== null && ENABLE_OLLAMA_API !== null}
@@ -300,7 +304,7 @@
-
+
diff --git a/src/lib/components/admin/Settings/Database.svelte b/src/lib/components/admin/Settings/Database.svelte
index 9ce8e9d8d..ae3077be4 100644
--- a/src/lib/components/admin/Settings/Database.svelte
+++ b/src/lib/components/admin/Settings/Database.svelte
@@ -30,7 +30,7 @@
 		saveHandler();
 	}}
>
-
+
 	{$i18n.t('Database')}
diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/admin/Settings/Documents.svelte
similarity index 83%
rename from src/lib/components/documents/Settings/General.svelte
rename to src/lib/components/admin/Settings/Documents.svelte
index 34d8f5787..0e6527813 100644
--- a/src/lib/components/documents/Settings/General.svelte
+++ b/src/lib/components/admin/Settings/Documents.svelte
@@ -9,7 +9,9 @@
 		updateEmbeddingConfig,
 		getRerankingConfig,
 		updateRerankingConfig,
-		resetUploadDir
+		resetUploadDir,
+		getRAGConfig,
+		updateRAGConfig
 	} from '$lib/apis/rag';
 
 	import { documents, models } from '$lib/stores';
@@ -31,6 +33,10 @@
 	let embeddingModel = '';
 	let rerankingModel = '';
 
+	let chunkSize = 0;
+	let chunkOverlap = 0;
+	let pdfExtractImages = true;
+
 	let OpenAIKey = '';
 	let OpenAIUrl = '';
 	let OpenAIBatchSize = 1;
@@ -152,6 +158,14 @@
 		if (querySettings.hybrid) {
 			rerankingModelUpdateHandler();
 		}
+
+		const res = await updateRAGConfig(localStorage.token, {
+			pdf_extract_images: pdfExtractImages,
+			chunk: {
+				chunk_overlap: chunkOverlap,
+				chunk_size: chunkSize
+			}
+		});
 	};
 
 	const setEmbeddingConfig = async () => {
@@ -185,6 +199,15 @@
 		await setRerankingConfig();
 
 		querySettings = await getQuerySettings(localStorage.token);
+
+		const res = await getRAGConfig(localStorage.token);
+
+		if (res) {
+			pdfExtractImages = res.pdf_extract_images;
+
+			chunkSize = res.chunk.chunk_size;
+			chunkOverlap = res.chunk.chunk_overlap;
+		}
 	});
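The chunking config now round-trips through the new RAG config endpoints. A small usage sketch, reusing the helper names and payload shape exactly as they appear in the hunk above:

```ts
import { getRAGConfig, updateRAGConfig } from '$lib/apis/rag';

// Sketch: read the current chunking config, adjust it, and write it back
// using the same payload shape as the submit handler above.
const shrinkChunks = async (token: string) => {
	const res = await getRAGConfig(token);

	await updateRAGConfig(token, {
		pdf_extract_images: res.pdf_extract_images,
		chunk: {
			chunk_size: Math.min(res.chunk.chunk_size, 1000),
			chunk_overlap: res.chunk.chunk_overlap
		}
	});
};
```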
@@ -195,7 +218,7 @@
 		saveHandler();
 	}}
>
-
+
 	{$i18n.t('General Settings')}
@@ -332,7 +355,7 @@
-
+
@@ -350,10 +373,8 @@
 	{#if !embeddingModel}
 	{/if}
-	{#each $models.filter((m) => m.id && !m.external) as model}
-
+	{#each $models.filter((m) => m.id && m.ollama && !(m?.preset ?? false)) as model}
+
 	{/each}
@@ -500,6 +521,122 @@
+	{$i18n.t('Query Params')}
+	{$i18n.t('Top K')}
+
+	{#if querySettings.hybrid === true}
+		{$i18n.t('Minimum Score')}
+	{/if}
+
+	{#if querySettings.hybrid === true}
+		{$i18n.t(
+			'Note: If you set a minimum score, the search will only return documents with a score greater than or equal to the minimum score.'
+		)}
+	{/if}
+
+	{$i18n.t('RAG Template')}
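The Query Params block above edits `querySettings`. A sketch of a save call, with heavy caveats: `updateQuerySettings` and the `k`/`r` field names are assumptions standing in for the Top K and Minimum Score controls; only `getQuerySettings` and `hybrid` appear in this diff:

```ts
import { getQuerySettings, updateQuerySettings } from '$lib/apis/rag'; // update helper assumed

// Sketch: Top K always applies; the minimum relevance score only matters
// when hybrid search is enabled, mirroring the {#if querySettings.hybrid}
// guards above.
const tuneQueryParams = async (token: string) => {
	const settings = await getQuerySettings(token);

	await updateQuerySettings(token, {
		...settings,
		k: 4, // Top K: assumed field name
		...(settings.hybrid === true && { r: 0.2 }) // Minimum Score: assumed field name
	});
};
```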