diff --git a/README.md b/README.md index cd06bc384..d4d31ba64 100644 --- a/README.md +++ b/README.md @@ -25,22 +25,28 @@ Open WebUI is an extensible, feature-rich, and user-friendly self-hosted WebUI d - 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience. +- 🌈 **Theme Customization**: Choose from a variety of themes to personalize your Open WebUI experience. + - 💻 **Code Syntax Highlighting**: Enjoy enhanced code readability with our syntax highlighting feature. - ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction. - 📚 **Local RAG Integration**: Dive into the future of chat interactions with the groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using `#` command in the prompt. In its alpha phase, occasional issues may arise as we actively refine and enhance this feature to ensure optimal performance and reliability. +- 🔍 **RAG Embedding Support**: Change the RAG embedding model directly in document settings, enhancing document processing. This feature supports Ollama and OpenAI models. + - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by the URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions. - 📜 **Prompt Preset Support**: Instantly access preset prompts using the `/` command in the chat input. Load predefined conversation starters effortlessly and expedite your interactions. Effortlessly import prompts through [Open WebUI Community](https://openwebui.com/) integration. -- 👍👎 **RLHF Annotation**: Empower your messages by rating them with thumbs up and thumbs down, facilitating the creation of datasets for Reinforcement Learning from Human Feedback (RLHF). Utilize your messages to train or fine-tune models, all while ensuring the confidentiality of locally saved data. +- 👍👎 **RLHF Annotation**: Empower your messages by rating them with thumbs up and thumbs down, followed by the option to provide textual feedback, facilitating the creation of datasets for Reinforcement Learning from Human Feedback (RLHF). Utilize your messages to train or fine-tune models, all while ensuring the confidentiality of locally saved data. - 🏷️ **Conversation Tagging**: Effortlessly categorize and locate specific chats for quick reference and streamlined data collection. - 📥🗑️ **Download/Delete Models**: Easily download or remove models directly from the web UI. +- 🔄 **Update All Ollama Models**: Easily update locally installed models all at once with a convenient button, streamlining model management. + - ⬆️ **GGUF File Model Creation**: Effortlessly create Ollama models by uploading GGUF files directly from the web UI. Streamlined process with options to upload from your machine or download GGUF files from Hugging Face. - 🤖 **Multiple Model Support**: Seamlessly switch between different chat models for diverse interactions. @@ -53,28 +59,42 @@ Open WebUI is an extensible, feature-rich, and user-friendly self-hosted WebUI d - 💬 **Collaborative Chat**: Harness the collective intelligence of multiple models by seamlessly orchestrating group conversations. 
Use the `@` command to specify the model, enabling dynamic and diverse dialogues within your chat interface. Immerse yourself in the collective intelligence woven into your chat environment. +- 🗨️ **Local Chat Sharing**: Generate and share chat links seamlessly between users, enhancing collaboration and communication. + - 🔄 **Regeneration History Access**: Easily revisit and explore your entire regeneration history. - 📜 **Chat History**: Effortlessly access and manage your conversation history. +- 📬 **Archive Chats**: Effortlessly store away completed conversations with LLMs for future reference, maintaining a tidy and clutter-free chat interface while allowing for easy retrieval and reference. + - 📤📥 **Import/Export Chat History**: Seamlessly move your chat data in and out of the platform. - 🗣️ **Voice Input Support**: Engage with your model through voice interactions; enjoy the convenience of talking to your model directly. Additionally, explore the option for sending voice input automatically after 3 seconds of silence for a streamlined experience. +- 🔊 **Configurable Text-to-Speech Endpoint**: Customize your Text-to-Speech experience with configurable OpenAI endpoints. + - ⚙️ **Fine-Tuned Control with Advanced Parameters**: Gain a deeper level of control by adjusting parameters such as temperature and defining your system prompts to tailor the conversation to your specific preferences and needs. -- 🎨🤖 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using AUTOMATIC1111 API (local) and DALL-E, enriching your chat experience with dynamic visual content. +- 🎨🤖 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using options such as AUTOMATIC1111 API (local), ComfyUI (local), and DALL-E, enriching your chat experience with dynamic visual content. - 🤝 **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**. - ✨ **Multiple OpenAI-Compatible API Support**: Seamlessly integrate and customize various OpenAI-compatible APIs, enhancing the versatility of your chat interactions. +- 🔑 **API Key Generation Support**: Generate secret keys to leverage Open WebUI with OpenAI libraries, simplifying integration and development. + - 🔗 **External Ollama Server Connection**: Seamlessly link to an external Ollama server hosted on a different address by configuring the environment variable. - 🔀 **Multiple Ollama Instance Load Balancing**: Effortlessly distribute chat requests across multiple Ollama instances for enhanced performance and reliability. - 👥 **Multi-User Management**: Easily oversee and administer users via our intuitive admin panel, streamlining user management processes. +- 🔗 **Webhook Integration**: Subscribe to new user sign-up events via webhook (compatible with Google Chat and Microsoft Teams), providing real-time notifications and automation capabilities. + +- 🛡️ **Model Whitelisting**: Admins can whitelist models for users with the 'user' role, enhancing security and access control. + +- 📧 **Trusted Email Authentication**: Authenticate using a trusted email header, adding an additional layer of security and authentication. + - 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators. 
- 🔒 **Backend Reverse Proxy Support**: Bolster security through direct communication between Open WebUI backend and Ollama. This key feature eliminates the need to expose Ollama over LAN. Requests made to the '/ollama/api' route from the web UI are seamlessly redirected to Ollama from the backend, enhancing overall system security. diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py index 393333255..05df1c166 100644 --- a/backend/apps/images/utils/comfyui.py +++ b/backend/apps/images/utils/comfyui.py @@ -195,7 +195,7 @@ class ImageGenerationPayload(BaseModel): def comfyui_generate_image( model: str, payload: ImageGenerationPayload, client_id, base_url ): - host = base_url.replace("http://", "").replace("https://", "") + ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://") comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT) @@ -217,7 +217,7 @@ def comfyui_generate_image( try: ws = websocket.WebSocket() - ws.connect(f"ws://{host}/ws?clientId={client_id}") + ws.connect(f"{ws_url}/ws?clientId={client_id}") log.info("WebSocket connection established.") except Exception as e: log.exception(f"Failed to connect to WebSocket server: {e}") diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index a9922aad7..52e0c7002 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -1,100 +1,336 @@ +from fastapi import FastAPI, Depends, HTTPException +from fastapi.routing import APIRoute +from fastapi.middleware.cors import CORSMiddleware + import logging - -from litellm.proxy.proxy_server import ProxyConfig, initialize -from litellm.proxy.proxy_server import app - from fastapi import FastAPI, Request, Depends, status, Response from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.responses import StreamingResponse import json +import time +import requests -from utils.utils import get_http_authorization_cred, get_current_user +from pydantic import BaseModel, ConfigDict +from typing import Optional, List + +from utils.utils import get_verified_user, get_current_user, get_admin_user from config import SRC_LOG_LEVELS, ENV +from constants import MESSAGES log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["LITELLM"]) -from config import ( - MODEL_FILTER_ENABLED, - MODEL_FILTER_LIST, +from config import MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, DATA_DIR + +from litellm.utils import get_llm_provider + +import asyncio +import subprocess +import yaml + +app = FastAPI() + +origins = ["*"] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) -proxy_config = ProxyConfig() +LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml" + +with open(LITELLM_CONFIG_DIR, "r") as file: + litellm_config = yaml.safe_load(file) + +app.state.CONFIG = litellm_config + +# Global variable to store the subprocess reference +background_process = None -async def config(): - router, model_list, general_settings = await proxy_config.load_config( - router=None, config_file_path="./data/litellm/config.yaml" +async def run_background_process(command): + global background_process + log.info("run_background_process") + + try: + # Log the command to be executed + log.info(f"Executing command: {command}") + # Execute the command and create a subprocess + process = await asyncio.create_subprocess_exec( + *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE + 
) + background_process = process + log.info("Subprocess started successfully.") + + # Capture STDERR for debugging purposes + stderr_output = await process.stderr.read() + stderr_text = stderr_output.decode().strip() + if stderr_text: + log.info(f"Subprocess STDERR: {stderr_text}") + + # log.info output line by line + async for line in process.stdout: + log.info(line.decode().strip()) + + # Wait for the process to finish + returncode = await process.wait() + log.info(f"Subprocess exited with return code {returncode}") + except Exception as e: + log.error(f"Failed to start subprocess: {e}") + raise # Optionally re-raise the exception if you want it to propagate + + +async def start_litellm_background(): + log.info("start_litellm_background") + # Command to run in the background + command = ( + "litellm --port 14365 --telemetry False --config ./data/litellm/config.yaml" ) - await initialize(config="./data/litellm/config.yaml", telemetry=False) + await run_background_process(command) -async def startup(): - await config() +async def shutdown_litellm_background(): + log.info("shutdown_litellm_background") + global background_process + if background_process: + background_process.terminate() + await background_process.wait() # Ensure the process has terminated + log.info("Subprocess terminated") + background_process = None @app.on_event("startup") -async def on_startup(): - await startup() +async def startup_event(): + + log.info("startup_event") + # TODO: Check config.yaml file and create one + asyncio.create_task(start_litellm_background()) app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST -@app.middleware("http") -async def auth_middleware(request: Request, call_next): - auth_header = request.headers.get("Authorization", "") - request.state.user = None +@app.get("/") +async def get_status(): + return {"status": True} + + +async def restart_litellm(): + """ + Endpoint to restart the litellm background service. 
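+ Shuts down any running litellm subprocess and schedules a fresh one; exposed to admins via the /restart route below.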
+ """ + log.info("Requested restart of litellm service.") + try: + # Shut down the existing process if it is running + await shutdown_litellm_background() + log.info("litellm service shutdown complete.") + + # Restart the background service + + asyncio.create_task(start_litellm_background()) + log.info("litellm service restart complete.") + + return { + "status": "success", + "message": "litellm service restarted successfully.", + } + except Exception as e: + log.info(f"Error restarting litellm service: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) + ) + + +@app.get("/restart") +async def restart_litellm_handler(user=Depends(get_admin_user)): + return await restart_litellm() + + +@app.get("/config") +async def get_config(user=Depends(get_admin_user)): + return app.state.CONFIG + + +class LiteLLMConfigForm(BaseModel): + general_settings: Optional[dict] = None + litellm_settings: Optional[dict] = None + model_list: Optional[List[dict]] = None + router_settings: Optional[dict] = None + + model_config = ConfigDict(protected_namespaces=()) + + +@app.post("/config/update") +async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)): + app.state.CONFIG = form_data.model_dump(exclude_none=True) + + with open(LITELLM_CONFIG_DIR, "w") as file: + yaml.dump(app.state.CONFIG, file) + + await restart_litellm() + return app.state.CONFIG + + +@app.get("/models") +@app.get("/v1/models") +async def get_models(user=Depends(get_current_user)): + while not background_process: + await asyncio.sleep(0.1) + + url = "http://localhost:14365/v1" + r = None + try: + r = requests.request(method="GET", url=f"{url}/models") + r.raise_for_status() + + data = r.json() + + if app.state.MODEL_FILTER_ENABLED: + if user and user.role == "user": + data["data"] = list( + filter( + lambda model: model["id"] in app.state.MODEL_FILTER_LIST, + data["data"], + ) + ) + + return data + except Exception as e: + + log.exception(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"External: {res['error']}" + except: + error_detail = f"External: {e}" + + return { + "data": [ + { + "id": model["model_name"], + "object": "model", + "created": int(time.time()), + "owned_by": "openai", + } + for model in app.state.CONFIG["model_list"] + ], + "object": "list", + } + + +@app.get("/model/info") +async def get_model_list(user=Depends(get_admin_user)): + return {"data": app.state.CONFIG["model_list"]} + + +class AddLiteLLMModelForm(BaseModel): + model_name: str + litellm_params: dict + + model_config = ConfigDict(protected_namespaces=()) + + +@app.post("/model/new") +async def add_model_to_config( + form_data: AddLiteLLMModelForm, user=Depends(get_admin_user) +): + try: + get_llm_provider(model=form_data.model_name) + app.state.CONFIG["model_list"].append(form_data.model_dump()) + + with open(LITELLM_CONFIG_DIR, "w") as file: + yaml.dump(app.state.CONFIG, file) + + await restart_litellm() + + return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)} + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) + ) + + +class DeleteLiteLLMModelForm(BaseModel): + id: str + + +@app.post("/model/delete") +async def delete_model_from_config( + form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user) +): + app.state.CONFIG["model_list"] = [ + model + for model in app.state.CONFIG["model_list"] + if model["model_name"] 
!= form_data.id + ] + + with open(LITELLM_CONFIG_DIR, "w") as file: + yaml.dump(app.state.CONFIG, file) + + await restart_litellm() + + return {"message": MESSAGES.MODEL_DELETED(form_data.id)} + + +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +async def proxy(path: str, request: Request, user=Depends(get_verified_user)): + body = await request.body() + + url = "http://localhost:14365" + + target_url = f"{url}/{path}" + + headers = {} + # headers["Authorization"] = f"Bearer {key}" + headers["Content-Type"] = "application/json" + + r = None try: - user = get_current_user(get_http_authorization_cred(auth_header)) - log.debug(f"user: {user}") - request.state.user = user + r = requests.request( + method=request.method, + url=target_url, + data=body, + headers=headers, + stream=True, + ) + + r.raise_for_status() + + # Check if response is SSE + if "text/event-stream" in r.headers.get("Content-Type", ""): + return StreamingResponse( + r.iter_content(chunk_size=8192), + status_code=r.status_code, + headers=dict(r.headers), + ) + else: + response_data = r.json() + return response_data except Exception as e: - return JSONResponse(status_code=400, content={"detail": str(e)}) + log.exception(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" + except: + error_detail = f"External: {e}" - response = await call_next(request) - return response - - -class ModifyModelsResponseMiddleware(BaseHTTPMiddleware): - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: - - response = await call_next(request) - user = request.state.user - - if "/models" in request.url.path: - if isinstance(response, StreamingResponse): - # Read the content of the streaming response - body = b"" - async for chunk in response.body_iterator: - body += chunk - - data = json.loads(body.decode("utf-8")) - - if app.state.MODEL_FILTER_ENABLED: - if user and user.role == "user": - data["data"] = list( - filter( - lambda model: model["id"] - in app.state.MODEL_FILTER_LIST, - data["data"], - ) - ) - - # Modified Flag - data["modified"] = True - return JSONResponse(content=data) - - return response - - -app.add_middleware(ModifyModelsResponseMiddleware) + raise HTTPException( + status_code=r.status_code if r else 500, detail=error_detail + ) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 4647d7489..0fbbd365e 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -80,6 +80,7 @@ async def get_openai_urls(user=Depends(get_admin_user)): @app.post("/urls/update") async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): + await get_all_models() app.state.OPENAI_API_BASE_URLS = form_data.urls return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} diff --git a/backend/apps/web/models/tags.py b/backend/apps/web/models/tags.py index 196551b7b..02de5b9d7 100644 --- a/backend/apps/web/models/tags.py +++ b/backend/apps/web/models/tags.py @@ -136,7 +136,9 @@ class TagTable: return [ TagModel(**model_to_dict(tag)) - for tag in Tag.select().where(Tag.name.in_(tag_names)) + for tag in Tag.select() + .where(Tag.user_id == user_id) + .where(Tag.name.in_(tag_names)) ] def get_tags_by_chat_id_and_user_id( @@ -151,7 +153,9 @@ class TagTable: return [ TagModel(**model_to_dict(tag)) - for tag in 
Tag.select().where(Tag.name.in_(tag_names)) + for tag in Tag.select() + .where(Tag.user_id == user_id) + .where(Tag.name.in_(tag_names)) ] def get_chat_ids_by_tag_name_and_user_id( diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py index 678c9aea7..bbe3d84b9 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/web/routers/chats.py @@ -28,7 +28,7 @@ from apps.web.models.tags import ( from constants import ERROR_MESSAGES -from config import SRC_LOG_LEVELS +from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -79,6 +79,11 @@ async def get_all_user_chats(user=Depends(get_current_user)): @router.get("/all/db", response_model=List[ChatResponse]) async def get_all_user_chats_in_db(user=Depends(get_admin_user)): + if not ENABLE_ADMIN_EXPORT: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) return [ ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) for chat in Chats.get_all_chats() diff --git a/backend/apps/web/routers/utils.py b/backend/apps/web/routers/utils.py index 0ee75cfe6..284f350a0 100644 --- a/backend/apps/web/routers/utils.py +++ b/backend/apps/web/routers/utils.py @@ -91,7 +91,11 @@ async def download_chat_as_pdf( @router.get("/db/download") async def download_db(user=Depends(get_admin_user)): - + if not ENABLE_ADMIN_EXPORT: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) return FileResponse( f"{DATA_DIR}/webui.db", media_type="application/octet-stream", diff --git a/backend/config.py b/backend/config.py index 29284667b..269f0eedb 100644 --- a/backend/config.py +++ b/backend/config.py @@ -322,9 +322,14 @@ OPENAI_API_BASE_URLS = [ ] OPENAI_API_KEY = "" -OPENAI_API_KEY = OPENAI_API_KEYS[ - OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") -] + +try: + OPENAI_API_KEY = OPENAI_API_KEYS[ + OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") + ] +except: + pass + OPENAI_API_BASE_URL = "https://api.openai.com/v1" @@ -377,6 +382,8 @@ MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "") +ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true" + #################################### # WEBUI_VERSION #################################### diff --git a/backend/constants.py b/backend/constants.py index da1ee0b3f..310c13311 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -3,6 +3,10 @@ from enum import Enum class MESSAGES(str, Enum): DEFAULT = lambda msg="": f"{msg if msg else ''}" + MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully." + MODEL_DELETED = ( + lambda model="": f"The model '{model}' has been deleted successfully." 
+ ) class WEBHOOK_MESSAGES(str, Enum): diff --git a/backend/main.py b/backend/main.py index 655bdb6da..47a9ce310 100644 --- a/backend/main.py +++ b/backend/main.py @@ -20,12 +20,17 @@ from starlette.middleware.base import BaseHTTPMiddleware from apps.ollama.main import app as ollama_app from apps.openai.main import app as openai_app -from apps.litellm.main import app as litellm_app, startup as litellm_app_startup +from apps.litellm.main import ( + app as litellm_app, + start_litellm_background, + shutdown_litellm_background, +) from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app from apps.web.main import app as webui_app +import asyncio from pydantic import BaseModel from typing import List @@ -47,6 +52,7 @@ from config import ( GLOBAL_LOG_LEVEL, SRC_LOG_LEVELS, WEBHOOK_URL, + ENABLE_ADMIN_EXPORT, ) from constants import ERROR_MESSAGES @@ -171,7 +177,7 @@ async def check_url(request: Request, call_next): @app.on_event("startup") async def on_startup(): - await litellm_app_startup() + asyncio.create_task(start_litellm_background()) app.mount("/api/v1", webui_app) @@ -203,6 +209,7 @@ async def get_app_config(): "default_models": webui_app.state.DEFAULT_MODELS, "default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS, "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), + "admin_export_enabled": ENABLE_ADMIN_EXPORT, } @@ -316,3 +323,8 @@ app.mount( SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True), name="spa-static-files", ) + + +@app.on_event("shutdown") +async def shutdown_event(): + await shutdown_litellm_background() diff --git a/backend/requirements.txt b/backend/requirements.txt index d5c179d86..10bcc3b69 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -17,7 +17,9 @@ peewee peewee-migrate bcrypt -litellm==1.30.7 +litellm==1.35.17 +litellm[proxy]==1.35.17 + boto3 argon2-cffi diff --git a/src/lib/apis/streaming/index.ts b/src/lib/apis/streaming/index.ts new file mode 100644 index 000000000..5b89a4668 --- /dev/null +++ b/src/lib/apis/streaming/index.ts @@ -0,0 +1,70 @@ +type TextStreamUpdate = { + done: boolean; + value: string; +}; + +// createOpenAITextStream takes a ReadableStreamDefaultReader from an SSE response, +// and returns an async generator that emits delta updates with large deltas chunked into random sized chunks +export async function createOpenAITextStream( + messageStream: ReadableStreamDefaultReader, + splitLargeDeltas: boolean +): Promise<AsyncGenerator<TextStreamUpdate>> { + let iterator = openAIStreamToIterator(messageStream); + if (splitLargeDeltas) { + iterator = streamLargeDeltasAsRandomChunks(iterator); + } + return iterator; +} + +async function* openAIStreamToIterator( + reader: ReadableStreamDefaultReader +): AsyncGenerator<TextStreamUpdate> { + while (true) { + const { value, done } = await reader.read(); + if (done) { + yield { done: true, value: '' }; + break; + } + const lines = value.split('\n'); + for (const line of lines) { + if (line !== '') { + console.log(line); + if (line === 'data: [DONE]') { + yield { done: true, value: '' }; + } else { + const data = JSON.parse(line.replace(/^data: /, '')); + console.log(data); + + yield { done: false, value: data.choices[0].delta.content ??
'' }; + } + } + } + } +} + +// streamLargeDeltasAsRandomChunks will chunk large deltas (length > 5) into random sized chunks between 1-3 characters +// This is to simulate a more fluid streaming, even though some providers may send large chunks of text at once +async function* streamLargeDeltasAsRandomChunks( + iterator: AsyncGenerator<TextStreamUpdate> +): AsyncGenerator<TextStreamUpdate> { + for await (const textStreamUpdate of iterator) { + if (textStreamUpdate.done) { + yield textStreamUpdate; + return; + } + let content = textStreamUpdate.value; + if (content.length < 5) { + yield { done: false, value: content }; + continue; + } + while (content != '') { + const chunkSize = Math.min(Math.floor(Math.random() * 3) + 1, content.length); + const chunk = content.slice(0, chunkSize); + yield { done: false, value: chunk }; + await sleep(5); + content = content.slice(chunkSize); + } + } +} + +const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)); diff --git a/src/lib/components/admin/Settings/Database.svelte b/src/lib/components/admin/Settings/Database.svelte index 7d3a34444..06a0d595c 100644 --- a/src/lib/components/admin/Settings/Database.svelte +++ b/src/lib/components/admin/Settings/Database.svelte @@ -1,6 +1,7 @@ diff --git a/src/lib/components/chat/ShareChatModal.svelte b/src/lib/components/chat/ShareChatModal.svelte index 96ff12cdf..447274ceb 100644 --- a/src/lib/components/chat/ShareChatModal.svelte +++ b/src/lib/components/chat/ShareChatModal.svelte @@ -134,11 +134,36 @@
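The diff introduces `createOpenAITextStream`, but its call site is outside the visible hunks. Below is a minimal sketch of how a client component might consume it; the fetch route (`/openai/api/chat/completions`), the payload shape, the line-splitting transform, and the hard-coded `true` for delta splitting are illustrative assumptions, not code taken from this PR.

```ts
// Hypothetical call site for the new createOpenAITextStream helper.
import { createOpenAITextStream } from '$lib/apis/streaming';

// Re-chunk decoded text on newlines so the SSE iterator always receives
// complete `data: ...` lines, even if the network delivers partial ones.
const splitLines = () => {
	let buffer = '';
	return new TransformStream<string, string>({
		transform(chunk, controller) {
			buffer += chunk;
			const lines = buffer.split('\n');
			buffer = lines.pop() ?? '';
			for (const line of lines) {
				controller.enqueue(line + '\n');
			}
		},
		flush(controller) {
			if (buffer !== '') {
				controller.enqueue(buffer);
			}
		}
	});
};

export const streamChatCompletion = async (
	token: string,
	body: object,
	onDelta: (text: string) => void
) => {
	// Assumed OpenAI-compatible proxy route; adjust to the deployment's actual API base.
	const res = await fetch('/openai/api/chat/completions', {
		method: 'POST',
		headers: { Authorization: `Bearer ${token}`, 'Content-Type': 'application/json' },
		body: JSON.stringify({ ...body, stream: true })
	});
	if (!res.ok || !res.body) {
		throw new Error(`Request failed with status ${res.status}`);
	}

	// Decode bytes to text, split into lines, then hand a reader to the helper.
	const reader = res.body.pipeThrough(new TextDecoderStream()).pipeThrough(splitLines()).getReader();

	// `true` opts into the random 1-3 character re-chunking of large deltas.
	const textStream = await createOpenAITextStream(reader, true);
	for await (const update of textStream) {
		if (update.done) {
			break;
		}
		onDelta(update.value);
	}
};
```

The newline re-chunking matters because `openAIStreamToIterator` JSON-parses each `data:` line as a unit; without it, a delta split across network chunks could fail to parse.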