diff --git a/.env.example b/.env.example index 2d782fce1..c38bf88bf 100644 --- a/.env.example +++ b/.env.example @@ -10,8 +10,4 @@ OPENAI_API_KEY='' # DO NOT TRACK SCARF_NO_ANALYTICS=true DO_NOT_TRACK=true -ANONYMIZED_TELEMETRY=false - -# Use locally bundled version of the LiteLLM cost map json -# to avoid repetitive startup connections -LITELLM_LOCAL_MODEL_COST_MAP="True" \ No newline at end of file +ANONYMIZED_TELEMETRY=false \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index e3b2e64d2..be5c1da41 100644 --- a/Dockerfile +++ b/Dockerfile @@ -63,11 +63,6 @@ ENV OPENAI_API_KEY="" \ DO_NOT_TRACK=true \ ANONYMIZED_TELEMETRY=false -# Use locally bundled version of the LiteLLM cost map json -# to avoid repetitive startup connections -ENV LITELLM_LOCAL_MODEL_COST_MAP="True" - - #### Other models ######################################################### ## whisper TTS model settings ## ENV WHISPER_MODEL="base" \ @@ -87,10 +82,10 @@ WORKDIR /app/backend ENV HOME /root # Create user and group if not root RUN if [ $UID -ne 0 ]; then \ - if [ $GID -ne 0 ]; then \ - addgroup --gid $GID app; \ - fi; \ - adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \ + if [ $GID -ne 0 ]; then \ + addgroup --gid $GID app; \ + fi; \ + adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \ fi RUN mkdir -p $HOME/.cache/chroma diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py deleted file mode 100644 index 6a355038b..000000000 --- a/backend/apps/litellm/main.py +++ /dev/null @@ -1,379 +0,0 @@ -import sys -from contextlib import asynccontextmanager - -from fastapi import FastAPI, Depends, HTTPException -from fastapi.routing import APIRoute -from fastapi.middleware.cors import CORSMiddleware - -import logging -from fastapi import FastAPI, Request, Depends, status, Response -from fastapi.responses import JSONResponse - -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.responses import StreamingResponse -import json -import time -import requests - -from pydantic import BaseModel, ConfigDict -from typing import Optional, List - -from utils.utils import get_verified_user, get_current_user, get_admin_user -from config import SRC_LOG_LEVELS, ENV -from constants import MESSAGES - -import os - -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["LITELLM"]) - - -from config import ( - ENABLE_LITELLM, - ENABLE_MODEL_FILTER, - MODEL_FILTER_LIST, - DATA_DIR, - LITELLM_PROXY_PORT, - LITELLM_PROXY_HOST, -) - -import warnings - -warnings.simplefilter("ignore") - -from litellm.utils import get_llm_provider - -import asyncio -import subprocess -import yaml - - -@asynccontextmanager -async def lifespan(app: FastAPI): - log.info("startup_event") - # TODO: Check config.yaml file and create one - asyncio.create_task(start_litellm_background()) - yield - - -app = FastAPI(lifespan=lifespan) - -origins = ["*"] - -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml" - -with open(LITELLM_CONFIG_DIR, "r") as file: - litellm_config = yaml.safe_load(file) - - -app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value -app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value - - -app.state.ENABLE = ENABLE_LITELLM -app.state.CONFIG = litellm_config - -# Global variable to store the subprocess reference -background_process = None - -CONFLICT_ENV_VARS = [ - # Uvicorn uses PORT, so LiteLLM might use it as well - "PORT", - # LiteLLM uses DATABASE_URL for Prisma connections - "DATABASE_URL", -] - - -async def run_background_process(command): - global background_process - log.info("run_background_process") - - try: - # Log the command to be executed - log.info(f"Executing command: {command}") - # Filter environment variables known to conflict with litellm - env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS} - # Execute the command and create a subprocess - process = await asyncio.create_subprocess_exec( - *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env - ) - background_process = process - log.info("Subprocess started successfully.") - - # Capture STDERR for debugging purposes - stderr_output = await process.stderr.read() - stderr_text = stderr_output.decode().strip() - if stderr_text: - log.info(f"Subprocess STDERR: {stderr_text}") - - # log.info output line by line - async for line in process.stdout: - log.info(line.decode().strip()) - - # Wait for the process to finish - returncode = await process.wait() - log.info(f"Subprocess exited with return code {returncode}") - except Exception as e: - log.error(f"Failed to start subprocess: {e}") - raise # Optionally re-raise the exception if you want it to propagate - - -async def start_litellm_background(): - log.info("start_litellm_background") - # Command to run in the background - command = [ - "litellm", - "--port", - str(LITELLM_PROXY_PORT), - "--host", - LITELLM_PROXY_HOST, - "--telemetry", - "False", - "--config", - LITELLM_CONFIG_DIR, - ] - - await run_background_process(command) - - -async def shutdown_litellm_background(): - log.info("shutdown_litellm_background") - global background_process - if background_process: - background_process.terminate() - await background_process.wait() # Ensure the process has terminated - log.info("Subprocess terminated") - background_process = None - - -@app.get("/") -async def get_status(): - return {"status": True} - - -async def restart_litellm(): - """ - Endpoint to restart the litellm background service. - """ - log.info("Requested restart of litellm service.") - try: - # Shut down the existing process if it is running - await shutdown_litellm_background() - log.info("litellm service shutdown complete.") - - # Restart the background service - - asyncio.create_task(start_litellm_background()) - log.info("litellm service restart complete.") - - return { - "status": "success", - "message": "litellm service restarted successfully.", - } - except Exception as e: - log.info(f"Error restarting litellm service: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) - ) - - -@app.get("/restart") -async def restart_litellm_handler(user=Depends(get_admin_user)): - return await restart_litellm() - - -@app.get("/config") -async def get_config(user=Depends(get_admin_user)): - return app.state.CONFIG - - -class LiteLLMConfigForm(BaseModel): - general_settings: Optional[dict] = None - litellm_settings: Optional[dict] = None - model_list: Optional[List[dict]] = None - router_settings: Optional[dict] = None - - model_config = ConfigDict(protected_namespaces=()) - - -@app.post("/config/update") -async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)): - app.state.CONFIG = form_data.model_dump(exclude_none=True) - - with open(LITELLM_CONFIG_DIR, "w") as file: - yaml.dump(app.state.CONFIG, file) - - await restart_litellm() - return app.state.CONFIG - - -@app.get("/models") -@app.get("/v1/models") -async def get_models(user=Depends(get_current_user)): - - if app.state.ENABLE: - while not background_process: - await asyncio.sleep(0.1) - - url = f"http://localhost:{LITELLM_PROXY_PORT}/v1" - r = None - try: - r = requests.request(method="GET", url=f"{url}/models") - r.raise_for_status() - - data = r.json() - - if app.state.ENABLE_MODEL_FILTER: - if user and user.role == "user": - data["data"] = list( - filter( - lambda model: model["id"] in app.state.MODEL_FILTER_LIST, - data["data"], - ) - ) - - return data - except Exception as e: - - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"External: {res['error']}" - except: - error_detail = f"External: {e}" - - return { - "data": [ - { - "id": model["model_name"], - "object": "model", - "created": int(time.time()), - "owned_by": "openai", - } - for model in app.state.CONFIG["model_list"] - ], - "object": "list", - } - else: - return { - "data": [], - "object": "list", - } - - -@app.get("/model/info") -async def get_model_list(user=Depends(get_admin_user)): - return {"data": app.state.CONFIG["model_list"]} - - -class AddLiteLLMModelForm(BaseModel): - model_name: str - litellm_params: dict - - model_config = ConfigDict(protected_namespaces=()) - - -@app.post("/model/new") -async def add_model_to_config( - form_data: AddLiteLLMModelForm, user=Depends(get_admin_user) -): - try: - get_llm_provider(model=form_data.model_name) - app.state.CONFIG["model_list"].append(form_data.model_dump()) - - with open(LITELLM_CONFIG_DIR, "w") as file: - yaml.dump(app.state.CONFIG, file) - - await restart_litellm() - - return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)} - except Exception as e: - print(e) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) - ) - - -class DeleteLiteLLMModelForm(BaseModel): - id: str - - -@app.post("/model/delete") -async def delete_model_from_config( - form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user) -): - app.state.CONFIG["model_list"] = [ - model - for model in app.state.CONFIG["model_list"] - if model["model_name"] != form_data.id - ] - - with open(LITELLM_CONFIG_DIR, "w") as file: - yaml.dump(app.state.CONFIG, file) - - await restart_litellm() - - return {"message": MESSAGES.MODEL_DELETED(form_data.id)} - - -@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(path: str, request: Request, user=Depends(get_verified_user)): - body = await request.body() - - url = f"http://localhost:{LITELLM_PROXY_PORT}" - - target_url = f"{url}/{path}" - - headers = {} - # headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - - r = None - - try: - r = requests.request( - method=request.method, - url=target_url, - data=body, - headers=headers, - stream=True, - ) - - r.raise_for_status() - - # Check if response is SSE - if "text/event-stream" in r.headers.get("Content-Type", ""): - return StreamingResponse( - r.iter_content(chunk_size=8192), - status_code=r.status_code, - headers=dict(r.headers), - ) - else: - response_data = r.json() - return response_data - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" - except: - error_detail = f"External: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, detail=error_detail - ) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index fb8a35a17..01e127074 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -29,8 +29,8 @@ import time from urllib.parse import urlparse from typing import Optional, List, Union - -from apps.web.models.users import Users +from apps.webui.models.models import Models +from apps.webui.models.users import Users from constants import ERROR_MESSAGES from utils.utils import ( decode_token, @@ -39,6 +39,8 @@ from utils.utils import ( get_admin_user, ) +from utils.models import get_model_id_from_custom_model_id + from config import ( SRC_LOG_LEVELS, @@ -68,7 +70,6 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} @@ -305,6 +306,9 @@ async def pull_model( r = None + # Admin should be able to pull models from any source + payload = {**form_data.model_dump(exclude_none=True), "insecure": True} + def get_request(): nonlocal url nonlocal r @@ -332,7 +336,7 @@ async def pull_model( r = requests.request( method="POST", url=f"{url}/api/pull", - data=form_data.model_dump_json(exclude_none=True).encode(), + data=json.dumps(payload), stream=True, ) @@ -875,14 +879,93 @@ async def generate_chat_completion( user=Depends(get_verified_user), ): + log.debug( + "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( + form_data.model_dump_json(exclude_none=True).encode() + ) + ) + + payload = { + **form_data.model_dump(exclude_none=True), + } + + model_id = form_data.model + model_info = Models.get_model_by_id(model_id) + + if model_info: + print(model_info) + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + model_info.params = model_info.params.model_dump() + + if model_info.params: + payload["options"] = {} + + payload["options"]["mirostat"] = model_info.params.get("mirostat", None) + payload["options"]["mirostat_eta"] = model_info.params.get( + "mirostat_eta", None + ) + payload["options"]["mirostat_tau"] = model_info.params.get( + "mirostat_tau", None + ) + payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None) + + payload["options"]["repeat_last_n"] = model_info.params.get( + "repeat_last_n", None + ) + payload["options"]["repeat_penalty"] = model_info.params.get( + "frequency_penalty", None + ) + + payload["options"]["temperature"] = model_info.params.get( + "temperature", None + ) + payload["options"]["seed"] = model_info.params.get("seed", None) + + payload["options"]["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + else None + ) + + payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None) + + payload["options"]["num_predict"] = model_info.params.get( + "max_tokens", None + ) + payload["options"]["top_k"] = model_info.params.get("top_k", None) + + payload["options"]["top_p"] = model_info.params.get("top_p", None) + + if model_info.params.get("system", None): + # Check if the payload already has a system message + # If not, add a system message to the payload + if payload.get("messages"): + for message in payload["messages"]: + if message.get("role") == "system": + message["content"] = ( + model_info.params.get("system", None) + message["content"] + ) + break + else: + payload["messages"].insert( + 0, + { + "role": "system", + "content": model_info.params.get("system", None), + }, + ) + if url_idx == None: - model = form_data.model + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if payload["model"] in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"]) else: raise HTTPException( status_code=400, @@ -892,16 +975,12 @@ async def generate_chat_completion( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + print(payload) + r = None - log.debug( - "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( - form_data.model_dump_json(exclude_none=True).encode() - ) - ) - def get_request(): - nonlocal form_data + nonlocal payload nonlocal r request_id = str(uuid.uuid4()) @@ -910,7 +989,7 @@ async def generate_chat_completion( def stream_content(): try: - if form_data.stream: + if payload.get("stream", None): yield json.dumps({"id": request_id, "done": False}) + "\n" for chunk in r.iter_content(chunk_size=8192): @@ -928,7 +1007,7 @@ async def generate_chat_completion( r = requests.request( method="POST", url=f"{url}/api/chat", - data=form_data.model_dump_json(exclude_none=True).encode(), + data=json.dumps(payload), stream=True, ) @@ -984,14 +1063,62 @@ async def generate_openai_chat_completion( user=Depends(get_verified_user), ): + payload = { + **form_data.model_dump(exclude_none=True), + } + + model_id = form_data.model + model_info = Models.get_model_by_id(model_id) + + if model_info: + print(model_info) + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + model_info.params = model_info.params.model_dump() + + if model_info.params: + payload["temperature"] = model_info.params.get("temperature", None) + payload["top_p"] = model_info.params.get("top_p", None) + payload["max_tokens"] = model_info.params.get("max_tokens", None) + payload["frequency_penalty"] = model_info.params.get( + "frequency_penalty", None + ) + payload["seed"] = model_info.params.get("seed", None) + payload["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + else None + ) + + if model_info.params.get("system", None): + # Check if the payload already has a system message + # If not, add a system message to the payload + if payload.get("messages"): + for message in payload["messages"]: + if message.get("role") == "system": + message["content"] = ( + model_info.params.get("system", None) + message["content"] + ) + break + else: + payload["messages"].insert( + 0, + { + "role": "system", + "content": model_info.params.get("system", None), + }, + ) + if url_idx == None: - model = form_data.model + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if payload["model"] in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"]) else: raise HTTPException( status_code=400, @@ -1004,7 +1131,7 @@ async def generate_openai_chat_completion( r = None def get_request(): - nonlocal form_data + nonlocal payload nonlocal r request_id = str(uuid.uuid4()) @@ -1013,7 +1140,7 @@ async def generate_openai_chat_completion( def stream_content(): try: - if form_data.stream: + if payload.get("stream"): yield json.dumps( {"request_id": request_id, "done": False} ) + "\n" @@ -1033,7 +1160,7 @@ async def generate_openai_chat_completion( r = requests.request( method="POST", url=f"{url}/v1/chat/completions", - data=form_data.model_dump_json(exclude_none=True).encode(), + data=json.dumps(payload), stream=True, ) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 6659ebfcf..74ac18a12 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -10,8 +10,8 @@ import logging from pydantic import BaseModel - -from apps.web.models.users import Users +from apps.webui.models.models import Models +from apps.webui.models.users import Users from constants import ERROR_MESSAGES from utils.utils import ( decode_token, @@ -53,7 +53,6 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS @@ -206,7 +205,13 @@ def merge_models_lists(model_lists): if models is not None and "error" not in models: merged_list.extend( [ - {**model, "urlIdx": idx} + { + **model, + "name": model.get("name", model["id"]), + "owned_by": "openai", + "openai": model, + "urlIdx": idx, + } for model in models if "api.openai.com" not in app.state.config.OPENAI_API_BASE_URLS[idx] @@ -252,7 +257,7 @@ async def get_all_models(): log.info(f"models: {models}") app.state.MODELS = {model["id"]: model for model in models["data"]} - return models + return models @app.get("/models") @@ -310,39 +315,93 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): body = await request.body() # TODO: Remove below after gpt-4-vision fix from Open AI # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision) + + payload = None + try: - body = body.decode("utf-8") - body = json.loads(body) + if "chat/completions" in path: + body = body.decode("utf-8") + body = json.loads(body) - model = app.state.MODELS[body.get("model")] + payload = {**body} - idx = model["urlIdx"] + model_id = body.get("model") + model_info = Models.get_model_by_id(model_id) - if "pipeline" in model and model.get("pipeline"): - body["user"] = {"name": user.name, "id": user.id} - body["title"] = ( - True if body["stream"] == False and body["max_tokens"] == 50 else False - ) + if model_info: + print(model_info) + if model_info.base_model_id: + payload["model"] = model_info.base_model_id - # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 - # This is a workaround until OpenAI fixes the issue with this model - if body.get("model") == "gpt-4-vision-preview": - if "max_tokens" not in body: - body["max_tokens"] = 4000 - log.debug("Modified body_dict:", body) + model_info.params = model_info.params.model_dump() - # Fix for ChatGPT calls failing because the num_ctx key is in body - if "num_ctx" in body: - # If 'num_ctx' is in the dictionary, delete it - # Leaving it there generates an error with the - # OpenAI API (Feb 2024) - del body["num_ctx"] + if model_info.params: + payload["temperature"] = model_info.params.get("temperature", None) + payload["top_p"] = model_info.params.get("top_p", None) + payload["max_tokens"] = model_info.params.get("max_tokens", None) + payload["frequency_penalty"] = model_info.params.get( + "frequency_penalty", None + ) + payload["seed"] = model_info.params.get("seed", None) + payload["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + else None + ) + + if model_info.params.get("system", None): + # Check if the payload already has a system message + # If not, add a system message to the payload + if payload.get("messages"): + for message in payload["messages"]: + if message.get("role") == "system": + message["content"] = ( + model_info.params.get("system", None) + + message["content"] + ) + break + else: + payload["messages"].insert( + 0, + { + "role": "system", + "content": model_info.params.get("system", None), + }, + ) + else: + pass + + print(app.state.MODELS) + model = app.state.MODELS[payload.get("model")] + + idx = model["urlIdx"] + + if "pipeline" in model and model.get("pipeline"): + payload["user"] = {"name": user.name, "id": user.id} + payload["title"] = ( + True + if payload["stream"] == False and payload["max_tokens"] == 50 + else False + ) + + # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 + # This is a workaround until OpenAI fixes the issue with this model + if payload.get("model") == "gpt-4-vision-preview": + if "max_tokens" not in payload: + payload["max_tokens"] = 4000 + log.debug("Modified payload:", payload) + + # Convert the modified body back to JSON + payload = json.dumps(payload) - # Convert the modified body back to JSON - body = json.dumps(body) except json.JSONDecodeError as e: log.error("Error loading request body into a dictionary:", e) + print(payload) + url = app.state.config.OPENAI_API_BASE_URLS[idx] key = app.state.config.OPENAI_API_KEYS[idx] @@ -361,7 +420,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): r = requests.request( method=request.method, url=target_url, - data=body, + data=payload if payload else body, headers=headers, stream=True, ) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index f08d81a3b..d04c256d7 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -46,7 +46,7 @@ import json import sentence_transformers -from apps.web.models.documents import ( +from apps.webui.models.documents import ( Documents, DocumentForm, DocumentResponse, diff --git a/backend/apps/web/models/modelfiles.py b/backend/apps/web/models/modelfiles.py deleted file mode 100644 index 1d60d7c55..000000000 --- a/backend/apps/web/models/modelfiles.py +++ /dev/null @@ -1,136 +0,0 @@ -from pydantic import BaseModel -from peewee import * -from playhouse.shortcuts import model_to_dict -from typing import List, Union, Optional -import time - -from utils.utils import decode_token -from utils.misc import get_gravatar_url - -from apps.web.internal.db import DB - -import json - -#################### -# Modelfile DB Schema -#################### - - -class Modelfile(Model): - tag_name = CharField(unique=True) - user_id = CharField() - modelfile = TextField() - timestamp = BigIntegerField() - - class Meta: - database = DB - - -class ModelfileModel(BaseModel): - tag_name: str - user_id: str - modelfile: str - timestamp: int # timestamp in epoch - - -#################### -# Forms -#################### - - -class ModelfileForm(BaseModel): - modelfile: dict - - -class ModelfileTagNameForm(BaseModel): - tag_name: str - - -class ModelfileUpdateForm(ModelfileForm, ModelfileTagNameForm): - pass - - -class ModelfileResponse(BaseModel): - tag_name: str - user_id: str - modelfile: dict - timestamp: int # timestamp in epoch - - -class ModelfilesTable: - - def __init__(self, db): - self.db = db - self.db.create_tables([Modelfile]) - - def insert_new_modelfile( - self, user_id: str, form_data: ModelfileForm - ) -> Optional[ModelfileModel]: - if "tagName" in form_data.modelfile: - modelfile = ModelfileModel( - **{ - "user_id": user_id, - "tag_name": form_data.modelfile["tagName"], - "modelfile": json.dumps(form_data.modelfile), - "timestamp": int(time.time()), - } - ) - - try: - result = Modelfile.create(**modelfile.model_dump()) - if result: - return modelfile - else: - return None - except: - return None - - else: - return None - - def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]: - try: - modelfile = Modelfile.get(Modelfile.tag_name == tag_name) - return ModelfileModel(**model_to_dict(modelfile)) - except: - return None - - def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]: - return [ - ModelfileResponse( - **{ - **model_to_dict(modelfile), - "modelfile": json.loads(modelfile.modelfile), - } - ) - for modelfile in Modelfile.select() - # .limit(limit).offset(skip) - ] - - def update_modelfile_by_tag_name( - self, tag_name: str, modelfile: dict - ) -> Optional[ModelfileModel]: - try: - query = Modelfile.update( - modelfile=json.dumps(modelfile), - timestamp=int(time.time()), - ).where(Modelfile.tag_name == tag_name) - - query.execute() - - modelfile = Modelfile.get(Modelfile.tag_name == tag_name) - return ModelfileModel(**model_to_dict(modelfile)) - except: - return None - - def delete_modelfile_by_tag_name(self, tag_name: str) -> bool: - try: - query = Modelfile.delete().where((Modelfile.tag_name == tag_name)) - query.execute() # Remove the rows, return number of rows removed. - - return True - except: - return False - - -Modelfiles = ModelfilesTable(DB) diff --git a/backend/apps/web/routers/modelfiles.py b/backend/apps/web/routers/modelfiles.py deleted file mode 100644 index 3cdbf8a74..000000000 --- a/backend/apps/web/routers/modelfiles.py +++ /dev/null @@ -1,124 +0,0 @@ -from fastapi import Depends, FastAPI, HTTPException, status -from datetime import datetime, timedelta -from typing import List, Union, Optional - -from fastapi import APIRouter -from pydantic import BaseModel -import json -from apps.web.models.modelfiles import ( - Modelfiles, - ModelfileForm, - ModelfileTagNameForm, - ModelfileUpdateForm, - ModelfileResponse, -) - -from utils.utils import get_current_user, get_admin_user -from constants import ERROR_MESSAGES - -router = APIRouter() - -############################ -# GetModelfiles -############################ - - -@router.get("/", response_model=List[ModelfileResponse]) -async def get_modelfiles( - skip: int = 0, limit: int = 50, user=Depends(get_current_user) -): - return Modelfiles.get_modelfiles(skip, limit) - - -############################ -# CreateNewModelfile -############################ - - -@router.post("/create", response_model=Optional[ModelfileResponse]) -async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)): - modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) - - if modelfile: - return ModelfileResponse( - **{ - **modelfile.model_dump(), - "modelfile": json.loads(modelfile.modelfile), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), - ) - - -############################ -# GetModelfileByTagName -############################ - - -@router.post("/", response_model=Optional[ModelfileResponse]) -async def get_modelfile_by_tag_name( - form_data: ModelfileTagNameForm, user=Depends(get_current_user) -): - modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) - - if modelfile: - return ModelfileResponse( - **{ - **modelfile.model_dump(), - "modelfile": json.loads(modelfile.modelfile), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - - -############################ -# UpdateModelfileByTagName -############################ - - -@router.post("/update", response_model=Optional[ModelfileResponse]) -async def update_modelfile_by_tag_name( - form_data: ModelfileUpdateForm, user=Depends(get_admin_user) -): - modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) - if modelfile: - updated_modelfile = { - **json.loads(modelfile.modelfile), - **form_data.modelfile, - } - - modelfile = Modelfiles.update_modelfile_by_tag_name( - form_data.tag_name, updated_modelfile - ) - - return ModelfileResponse( - **{ - **modelfile.model_dump(), - "modelfile": json.loads(modelfile.modelfile), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - - -############################ -# DeleteModelfileByTagName -############################ - - -@router.delete("/delete", response_model=bool) -async def delete_modelfile_by_tag_name( - form_data: ModelfileTagNameForm, user=Depends(get_admin_user) -): - result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name) - return result diff --git a/backend/apps/web/internal/db.py b/backend/apps/webui/internal/db.py similarity index 67% rename from backend/apps/web/internal/db.py rename to backend/apps/webui/internal/db.py index a6051de50..0e7b1f95d 100644 --- a/backend/apps/web/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -1,3 +1,5 @@ +import json + from peewee import * from peewee_migrate import Router from playhouse.db_url import connect @@ -8,6 +10,16 @@ import logging log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["DB"]) + +class JSONField(TextField): + def db_value(self, value): + return json.dumps(value) + + def python_value(self, value): + if value is not None: + return json.loads(value) + + # Check if the file exists if os.path.exists(f"{DATA_DIR}/ollama.db"): # Rename the file @@ -19,7 +31,9 @@ else: DB = connect(DATABASE_URL) log.info(f"Connected to a {DB.__class__.__name__} database.") router = Router( - DB, migrate_dir=BACKEND_DIR / "apps" / "web" / "internal" / "migrations", logger=log + DB, + migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations", + logger=log, ) router.run() DB.connect(reuse_if_open=True) diff --git a/backend/apps/web/internal/migrations/001_initial_schema.py b/backend/apps/webui/internal/migrations/001_initial_schema.py similarity index 100% rename from backend/apps/web/internal/migrations/001_initial_schema.py rename to backend/apps/webui/internal/migrations/001_initial_schema.py diff --git a/backend/apps/web/internal/migrations/002_add_local_sharing.py b/backend/apps/webui/internal/migrations/002_add_local_sharing.py similarity index 100% rename from backend/apps/web/internal/migrations/002_add_local_sharing.py rename to backend/apps/webui/internal/migrations/002_add_local_sharing.py diff --git a/backend/apps/web/internal/migrations/003_add_auth_api_key.py b/backend/apps/webui/internal/migrations/003_add_auth_api_key.py similarity index 100% rename from backend/apps/web/internal/migrations/003_add_auth_api_key.py rename to backend/apps/webui/internal/migrations/003_add_auth_api_key.py diff --git a/backend/apps/web/internal/migrations/004_add_archived.py b/backend/apps/webui/internal/migrations/004_add_archived.py similarity index 100% rename from backend/apps/web/internal/migrations/004_add_archived.py rename to backend/apps/webui/internal/migrations/004_add_archived.py diff --git a/backend/apps/web/internal/migrations/005_add_updated_at.py b/backend/apps/webui/internal/migrations/005_add_updated_at.py similarity index 100% rename from backend/apps/web/internal/migrations/005_add_updated_at.py rename to backend/apps/webui/internal/migrations/005_add_updated_at.py diff --git a/backend/apps/web/internal/migrations/006_migrate_timestamps_and_charfields.py b/backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py similarity index 100% rename from backend/apps/web/internal/migrations/006_migrate_timestamps_and_charfields.py rename to backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py diff --git a/backend/apps/web/internal/migrations/007_add_user_last_active_at.py b/backend/apps/webui/internal/migrations/007_add_user_last_active_at.py similarity index 100% rename from backend/apps/web/internal/migrations/007_add_user_last_active_at.py rename to backend/apps/webui/internal/migrations/007_add_user_last_active_at.py diff --git a/backend/apps/web/internal/migrations/008_add_memory.py b/backend/apps/webui/internal/migrations/008_add_memory.py similarity index 100% rename from backend/apps/web/internal/migrations/008_add_memory.py rename to backend/apps/webui/internal/migrations/008_add_memory.py diff --git a/backend/apps/webui/internal/migrations/009_add_models.py b/backend/apps/webui/internal/migrations/009_add_models.py new file mode 100644 index 000000000..548ec7cdc --- /dev/null +++ b/backend/apps/webui/internal/migrations/009_add_models.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Model(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + base_model_id = pw.TextField(null=True) + + name = pw.TextField() + + meta = pw.TextField() + params = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "model" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("model") diff --git a/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py b/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py new file mode 100644 index 000000000..2ef814c06 --- /dev/null +++ b/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py @@ -0,0 +1,130 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator +import json + +from utils.misc import parse_ollama_modelfile + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + # Fetch data from 'modelfile' table and insert into 'model' table + migrate_modelfile_to_model(migrator, database) + # Drop the 'modelfile' table + migrator.remove_model("modelfile") + + +def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database): + ModelFile = migrator.orm["modelfile"] + Model = migrator.orm["model"] + + modelfiles = ModelFile.select() + + for modelfile in modelfiles: + # Extract and transform data in Python + + modelfile.modelfile = json.loads(modelfile.modelfile) + meta = json.dumps( + { + "description": modelfile.modelfile.get("desc"), + "profile_image_url": modelfile.modelfile.get("imageUrl"), + "ollama": {"modelfile": modelfile.modelfile.get("content")}, + "suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"), + "categories": modelfile.modelfile.get("categories"), + "user": {**modelfile.modelfile.get("user", {}), "community": True}, + } + ) + + info = parse_ollama_modelfile(modelfile.modelfile.get("content")) + + # Insert the processed data into the 'model' table + Model.create( + id=f"ollama-{modelfile.tag_name}", + user_id=modelfile.user_id, + base_model_id=info.get("base_model_id"), + name=modelfile.modelfile.get("title"), + meta=meta, + params=json.dumps(info.get("params", {})), + created_at=modelfile.timestamp, + updated_at=modelfile.timestamp, + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + recreate_modelfile_table(migrator, database) + move_data_back_to_modelfile(migrator, database) + migrator.remove_model("model") + + +def recreate_modelfile_table(migrator: Migrator, database: pw.Database): + query = """ + CREATE TABLE IF NOT EXISTS modelfile ( + user_id TEXT, + tag_name TEXT, + modelfile JSON, + timestamp BIGINT + ) + """ + migrator.sql(query) + + +def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database): + Model = migrator.orm["model"] + Modelfile = migrator.orm["modelfile"] + + models = Model.select() + + for model in models: + # Extract and transform data in Python + meta = json.loads(model.meta) + + modelfile_data = { + "title": model.name, + "desc": meta.get("description"), + "imageUrl": meta.get("profile_image_url"), + "content": meta.get("ollama", {}).get("modelfile"), + "suggestionPrompts": meta.get("suggestion_prompts"), + "categories": meta.get("categories"), + "user": {k: v for k, v in meta.get("user", {}).items() if k != "community"}, + } + + # Insert the processed data back into the 'modelfile' table + Modelfile.create( + user_id=model.user_id, + tag_name=model.id, + modelfile=modelfile_data, + timestamp=model.created_at, + ) diff --git a/backend/apps/web/internal/migrations/README.md b/backend/apps/webui/internal/migrations/README.md similarity index 84% rename from backend/apps/web/internal/migrations/README.md rename to backend/apps/webui/internal/migrations/README.md index 63d92e802..260214113 100644 --- a/backend/apps/web/internal/migrations/README.md +++ b/backend/apps/webui/internal/migrations/README.md @@ -14,7 +14,7 @@ You will need to create a migration file to ensure that existing databases are u 2. Make your changes to the models. 3. From the `backend` directory, run the following command: ```bash - pw_migrate create --auto --auto-source apps.web.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME} + pw_migrate create --auto --auto-source apps.webui.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME} ``` - `$SQLITE_DB` should be the path to the database file. - `$MIGRATION_NAME` should be a descriptive name for the migration. diff --git a/backend/apps/web/main.py b/backend/apps/webui/main.py similarity index 93% rename from backend/apps/web/main.py rename to backend/apps/webui/main.py index 24f6b8b51..e19382481 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/webui/main.py @@ -1,12 +1,12 @@ from fastapi import FastAPI, Depends from fastapi.routing import APIRoute from fastapi.middleware.cors import CORSMiddleware -from apps.web.routers import ( +from apps.webui.routers import ( auths, users, chats, documents, - modelfiles, + models, prompts, configs, memories, @@ -40,6 +40,9 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.WEBHOOK_URL = WEBHOOK_URL + + +app.state.MODELS = {} app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER @@ -56,11 +59,10 @@ app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(documents.router, prefix="/documents", tags=["documents"]) -app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) +app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(memories.router, prefix="/memories", tags=["memories"]) - app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(utils.router, prefix="/utils", tags=["utils"]) diff --git a/backend/apps/web/models/auths.py b/backend/apps/webui/models/auths.py similarity index 98% rename from backend/apps/web/models/auths.py rename to backend/apps/webui/models/auths.py index dfa0c4395..e3b659e43 100644 --- a/backend/apps/web/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -5,10 +5,10 @@ import uuid import logging from peewee import * -from apps.web.models.users import UserModel, Users +from apps.webui.models.users import UserModel, Users from utils.utils import verify_password -from apps.web.internal.db import DB +from apps.webui.internal.db import DB from config import SRC_LOG_LEVELS diff --git a/backend/apps/web/models/chats.py b/backend/apps/webui/models/chats.py similarity index 88% rename from backend/apps/web/models/chats.py rename to backend/apps/webui/models/chats.py index 891151b94..d4597f16d 100644 --- a/backend/apps/web/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -7,7 +7,7 @@ import json import uuid import time -from apps.web.internal.db import DB +from apps.webui.internal.db import DB #################### # Chat DB Schema @@ -191,6 +191,20 @@ class ChatTable: except: return None + def archive_all_chats_by_user_id(self, user_id: str) -> bool: + try: + chats = self.get_chats_by_user_id(user_id) + for chat in chats: + query = Chat.update( + archived=True, + ).where(Chat.id == chat.id) + + query.execute() + + return True + except: + return False + def get_archived_chat_list_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: @@ -205,17 +219,31 @@ class ChatTable: ] def get_chat_list_by_user_id( - self, user_id: str, skip: int = 0, limit: int = 50 + self, + user_id: str, + include_archived: bool = False, + skip: int = 0, + limit: int = 50, ) -> List[ChatModel]: - return [ - ChatModel(**model_to_dict(chat)) - for chat in Chat.select() - .where(Chat.archived == False) - .where(Chat.user_id == user_id) - .order_by(Chat.updated_at.desc()) - # .limit(limit) - # .offset(skip) - ] + if include_archived: + return [ + ChatModel(**model_to_dict(chat)) + for chat in Chat.select() + .where(Chat.user_id == user_id) + .order_by(Chat.updated_at.desc()) + # .limit(limit) + # .offset(skip) + ] + else: + return [ + ChatModel(**model_to_dict(chat)) + for chat in Chat.select() + .where(Chat.archived == False) + .where(Chat.user_id == user_id) + .order_by(Chat.updated_at.desc()) + # .limit(limit) + # .offset(skip) + ] def get_chat_list_by_chat_ids( self, chat_ids: List[str], skip: int = 0, limit: int = 50 diff --git a/backend/apps/web/models/documents.py b/backend/apps/webui/models/documents.py similarity index 99% rename from backend/apps/web/models/documents.py rename to backend/apps/webui/models/documents.py index 42b99596c..3b730535f 100644 --- a/backend/apps/web/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -8,7 +8,7 @@ import logging from utils.utils import decode_token from utils.misc import get_gravatar_url -from apps.web.internal.db import DB +from apps.webui.internal.db import DB import json diff --git a/backend/apps/web/models/memories.py b/backend/apps/webui/models/memories.py similarity index 97% rename from backend/apps/web/models/memories.py rename to backend/apps/webui/models/memories.py index 8382b3e52..70e5577e9 100644 --- a/backend/apps/web/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -3,8 +3,8 @@ from peewee import * from playhouse.shortcuts import model_to_dict from typing import List, Union, Optional -from apps.web.internal.db import DB -from apps.web.models.chats import Chats +from apps.webui.internal.db import DB +from apps.webui.models.chats import Chats import time import uuid diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py new file mode 100644 index 000000000..851352398 --- /dev/null +++ b/backend/apps/webui/models/models.py @@ -0,0 +1,179 @@ +import json +import logging +from typing import Optional + +import peewee as pw +from peewee import * + +from playhouse.shortcuts import model_to_dict +from pydantic import BaseModel, ConfigDict + +from apps.webui.internal.db import DB, JSONField + +from typing import List, Union, Optional +from config import SRC_LOG_LEVELS + +import time + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + + +#################### +# Models DB Schema +#################### + + +# ModelParams is a model for the data stored in the params field of the Model table +class ModelParams(BaseModel): + model_config = ConfigDict(extra="allow") + pass + + +# ModelMeta is a model for the data stored in the meta field of the Model table +class ModelMeta(BaseModel): + profile_image_url: Optional[str] = "/favicon.png" + + description: Optional[str] = None + """ + User-facing description of the model. + """ + + capabilities: Optional[dict] = None + + model_config = ConfigDict(extra="allow") + + pass + + +class Model(pw.Model): + id = pw.TextField(unique=True) + """ + The model's id as used in the API. If set to an existing model, it will override the model. + """ + user_id = pw.TextField() + + base_model_id = pw.TextField(null=True) + """ + An optional pointer to the actual model that should be used when proxying requests. + """ + + name = pw.TextField() + """ + The human-readable display name of the model. + """ + + params = JSONField() + """ + Holds a JSON encoded blob of parameters, see `ModelParams`. + """ + + meta = JSONField() + """ + Holds a JSON encoded blob of metadata, see `ModelMeta`. + """ + + updated_at = BigIntegerField() + created_at = BigIntegerField() + + class Meta: + database = DB + + +class ModelModel(BaseModel): + id: str + user_id: str + base_model_id: Optional[str] = None + + name: str + params: ModelParams + meta: ModelMeta + + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class ModelResponse(BaseModel): + id: str + name: str + meta: ModelMeta + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +class ModelForm(BaseModel): + id: str + base_model_id: Optional[str] = None + name: str + meta: ModelMeta + params: ModelParams + + +class ModelsTable: + def __init__( + self, + db: pw.SqliteDatabase | pw.PostgresqlDatabase, + ): + self.db = db + self.db.create_tables([Model]) + + def insert_new_model( + self, form_data: ModelForm, user_id: str + ) -> Optional[ModelModel]: + model = ModelModel( + **{ + **form_data.model_dump(), + "user_id": user_id, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + try: + result = Model.create(**model.model_dump()) + + if result: + return model + else: + return None + except Exception as e: + print(e) + return None + + def get_all_models(self) -> List[ModelModel]: + return [ModelModel(**model_to_dict(model)) for model in Model.select()] + + def get_model_by_id(self, id: str) -> Optional[ModelModel]: + try: + model = Model.get(Model.id == id) + return ModelModel(**model_to_dict(model)) + except: + return None + + def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: + try: + # update only the fields that are present in the model + query = Model.update(**model.model_dump()).where(Model.id == id) + query.execute() + + model = Model.get(Model.id == id) + return ModelModel(**model_to_dict(model)) + except Exception as e: + print(e) + + return None + + def delete_model_by_id(self, id: str) -> bool: + try: + query = Model.delete().where(Model.id == id) + query.execute() + return True + except: + return False + + +Models = ModelsTable(DB) diff --git a/backend/apps/web/models/prompts.py b/backend/apps/webui/models/prompts.py similarity index 98% rename from backend/apps/web/models/prompts.py rename to backend/apps/webui/models/prompts.py index bc4e3e58b..c4ac6be14 100644 --- a/backend/apps/web/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -7,7 +7,7 @@ import time from utils.utils import decode_token from utils.misc import get_gravatar_url -from apps.web.internal.db import DB +from apps.webui.internal.db import DB import json diff --git a/backend/apps/web/models/tags.py b/backend/apps/webui/models/tags.py similarity index 99% rename from backend/apps/web/models/tags.py rename to backend/apps/webui/models/tags.py index d9a967ff7..4c4fa82e6 100644 --- a/backend/apps/web/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -8,7 +8,7 @@ import uuid import time import logging -from apps.web.internal.db import DB +from apps.webui.internal.db import DB from config import SRC_LOG_LEVELS diff --git a/backend/apps/web/models/users.py b/backend/apps/webui/models/users.py similarity index 98% rename from backend/apps/web/models/users.py rename to backend/apps/webui/models/users.py index 450dd9187..8f600c6d5 100644 --- a/backend/apps/web/models/users.py +++ b/backend/apps/webui/models/users.py @@ -5,8 +5,8 @@ from typing import List, Union, Optional import time from utils.misc import get_gravatar_url -from apps.web.internal.db import DB -from apps.web.models.chats import Chats +from apps.webui.internal.db import DB +from apps.webui.models.chats import Chats #################### # User DB Schema diff --git a/backend/apps/web/routers/auths.py b/backend/apps/webui/routers/auths.py similarity index 99% rename from backend/apps/web/routers/auths.py rename to backend/apps/webui/routers/auths.py index 998e74659..ce9b92061 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/webui/routers/auths.py @@ -10,7 +10,7 @@ import uuid import csv -from apps.web.models.auths import ( +from apps.webui.models.auths import ( SigninForm, SignupForm, AddUserForm, @@ -21,7 +21,7 @@ from apps.web.models.auths import ( Auths, ApiKey, ) -from apps.web.models.users import Users +from apps.webui.models.users import Users from utils.utils import ( get_password_hash, diff --git a/backend/apps/web/routers/chats.py b/backend/apps/webui/routers/chats.py similarity index 95% rename from backend/apps/web/routers/chats.py rename to backend/apps/webui/routers/chats.py index aaf173521..5d52f40c9 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/webui/routers/chats.py @@ -7,8 +7,8 @@ from pydantic import BaseModel import json import logging -from apps.web.models.users import Users -from apps.web.models.chats import ( +from apps.webui.models.users import Users +from apps.webui.models.chats import ( ChatModel, ChatResponse, ChatTitleForm, @@ -18,7 +18,7 @@ from apps.web.models.chats import ( ) -from apps.web.models.tags import ( +from apps.webui.models.tags import ( TagModel, ChatIdTagModel, ChatIdTagForm, @@ -78,43 +78,25 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user) async def get_user_chat_list_by_user_id( user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50 ): - return Chats.get_chat_list_by_user_id(user_id, skip, limit) + return Chats.get_chat_list_by_user_id( + user_id, include_archived=True, skip=skip, limit=limit + ) ############################ -# GetArchivedChats +# CreateNewChat ############################ -@router.get("/archived", response_model=List[ChatTitleIdResponse]) -async def get_archived_session_user_chat_list( - user=Depends(get_current_user), skip: int = 0, limit: int = 50 -): - return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) - - -############################ -# GetSharedChatById -############################ - - -@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) -async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): - if user.role == "pending": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND - ) - - if user.role == "user": - chat = Chats.get_chat_by_share_id(share_id) - elif user.role == "admin": - chat = Chats.get_chat_by_id(share_id) - - if chat: +@router.post("/new", response_model=Optional[ChatResponse]) +async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): + try: + chat = Chats.insert_new_chat(user.id, form_data) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - else: + except Exception as e: + log.exception(e) raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() ) @@ -150,19 +132,49 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): ############################ -# CreateNewChat +# GetArchivedChats ############################ -@router.post("/new", response_model=Optional[ChatResponse]) -async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): - try: - chat = Chats.insert_new_chat(user.id, form_data) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - except Exception as e: - log.exception(e) +@router.get("/archived", response_model=List[ChatTitleIdResponse]) +async def get_archived_session_user_chat_list( + user=Depends(get_current_user), skip: int = 0, limit: int = 50 +): + return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) + + +############################ +# ArchiveAllChats +############################ + + +@router.post("/archive/all", response_model=List[ChatTitleIdResponse]) +async def archive_all_chats(user=Depends(get_current_user)): + return Chats.archive_all_chats_by_user_id(user.id) + + +############################ +# GetSharedChatById +############################ + + +@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) +async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): + if user.role == "pending": raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role == "user": + chat = Chats.get_chat_by_share_id(share_id) + elif user.role == "admin": + chat = Chats.get_chat_by_id(share_id) + + if chat: + return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND ) diff --git a/backend/apps/web/routers/configs.py b/backend/apps/webui/routers/configs.py similarity index 97% rename from backend/apps/web/routers/configs.py rename to backend/apps/webui/routers/configs.py index 143ed5e0a..00feafb18 100644 --- a/backend/apps/web/routers/configs.py +++ b/backend/apps/webui/routers/configs.py @@ -8,7 +8,7 @@ from pydantic import BaseModel import time import uuid -from apps.web.models.users import Users +from apps.webui.models.users import Users from utils.utils import ( get_password_hash, diff --git a/backend/apps/web/routers/documents.py b/backend/apps/webui/routers/documents.py similarity index 98% rename from backend/apps/web/routers/documents.py rename to backend/apps/webui/routers/documents.py index 7c69514fe..c5447a3fe 100644 --- a/backend/apps/web/routers/documents.py +++ b/backend/apps/webui/routers/documents.py @@ -6,7 +6,7 @@ from fastapi import APIRouter from pydantic import BaseModel import json -from apps.web.models.documents import ( +from apps.webui.models.documents import ( Documents, DocumentForm, DocumentUpdateForm, diff --git a/backend/apps/web/routers/memories.py b/backend/apps/webui/routers/memories.py similarity index 98% rename from backend/apps/web/routers/memories.py rename to backend/apps/webui/routers/memories.py index f20e02601..6448ebe1e 100644 --- a/backend/apps/web/routers/memories.py +++ b/backend/apps/webui/routers/memories.py @@ -7,7 +7,7 @@ from fastapi import APIRouter from pydantic import BaseModel import logging -from apps.web.models.memories import Memories, MemoryModel +from apps.webui.models.memories import Memories, MemoryModel from utils.utils import get_verified_user from constants import ERROR_MESSAGES diff --git a/backend/apps/webui/routers/models.py b/backend/apps/webui/routers/models.py new file mode 100644 index 000000000..363737e25 --- /dev/null +++ b/backend/apps/webui/routers/models.py @@ -0,0 +1,108 @@ +from fastapi import Depends, FastAPI, HTTPException, status, Request +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json +from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse + +from utils.utils import get_verified_user, get_admin_user +from constants import ERROR_MESSAGES + +router = APIRouter() + +########################### +# getModels +########################### + + +@router.get("/", response_model=List[ModelResponse]) +async def get_models(user=Depends(get_verified_user)): + return Models.get_all_models() + + +############################ +# AddNewModel +############################ + + +@router.post("/add", response_model=Optional[ModelModel]) +async def add_new_model( + request: Request, form_data: ModelForm, user=Depends(get_admin_user) +): + if form_data.id in request.app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.MODEL_ID_TAKEN, + ) + else: + model = Models.insert_new_model(form_data, user.id) + + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + +############################ +# GetModelById +############################ + + +@router.get("/", response_model=Optional[ModelModel]) +async def get_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateModelById +############################ + + +@router.post("/update", response_model=Optional[ModelModel]) +async def update_model_by_id( + request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user) +): + model = Models.get_model_by_id(id) + if model: + model = Models.update_model_by_id(id, form_data) + return model + else: + if form_data.id in request.app.state.MODELS: + model = Models.insert_new_model(form_data, user.id) + print(model) + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + +############################ +# DeleteModelById +############################ + + +@router.delete("/delete", response_model=bool) +async def delete_model_by_id(id: str, user=Depends(get_admin_user)): + result = Models.delete_model_by_id(id) + return result diff --git a/backend/apps/web/routers/prompts.py b/backend/apps/webui/routers/prompts.py similarity index 97% rename from backend/apps/web/routers/prompts.py rename to backend/apps/webui/routers/prompts.py index db7619676..47d8c7012 100644 --- a/backend/apps/web/routers/prompts.py +++ b/backend/apps/webui/routers/prompts.py @@ -6,7 +6,7 @@ from fastapi import APIRouter from pydantic import BaseModel import json -from apps.web.models.prompts import Prompts, PromptForm, PromptModel +from apps.webui.models.prompts import Prompts, PromptForm, PromptModel from utils.utils import get_current_user, get_admin_user from constants import ERROR_MESSAGES diff --git a/backend/apps/web/routers/users.py b/backend/apps/webui/routers/users.py similarity index 96% rename from backend/apps/web/routers/users.py rename to backend/apps/webui/routers/users.py index d77475d8d..bb9c557db 100644 --- a/backend/apps/web/routers/users.py +++ b/backend/apps/webui/routers/users.py @@ -9,9 +9,9 @@ import time import uuid import logging -from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users -from apps.web.models.auths import Auths -from apps.web.models.chats import Chats +from apps.webui.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users +from apps.webui.models.auths import Auths +from apps.webui.models.chats import Chats from utils.utils import get_verified_user, get_password_hash, get_admin_user from constants import ERROR_MESSAGES diff --git a/backend/apps/web/routers/utils.py b/backend/apps/webui/routers/utils.py similarity index 98% rename from backend/apps/web/routers/utils.py rename to backend/apps/webui/routers/utils.py index 12805873d..b95fe8834 100644 --- a/backend/apps/web/routers/utils.py +++ b/backend/apps/webui/routers/utils.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from fpdf import FPDF import markdown -from apps.web.internal.db import DB +from apps.webui.internal.db import DB from utils.utils import get_admin_user from utils.misc import calculate_sha256, get_gravatar_url diff --git a/backend/config.py b/backend/config.py index 8f259e1dc..5d074e250 100644 --- a/backend/config.py +++ b/backend/config.py @@ -27,6 +27,8 @@ from constants import ERROR_MESSAGES BACKEND_DIR = Path(__file__).parent # the path containing this file BASE_DIR = BACKEND_DIR.parent # the path containing the backend/ +print(BASE_DIR) + try: from dotenv import load_dotenv, find_dotenv @@ -56,7 +58,6 @@ log_sources = [ "CONFIG", "DB", "IMAGES", - "LITELLM", "MAIN", "MODELS", "OLLAMA", @@ -122,7 +123,10 @@ def parse_section(section): try: - changelog_content = (BASE_DIR / "CHANGELOG.md").read_text() + changelog_path = BASE_DIR / "CHANGELOG.md" + with open(str(changelog_path.absolute()), "r", encoding="utf8") as file: + changelog_content = file.read() + except: changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode() @@ -374,10 +378,10 @@ def create_config_file(file_path): LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml" -if not os.path.exists(LITELLM_CONFIG_PATH): - log.info("Config file doesn't exist. Creating...") - create_config_file(LITELLM_CONFIG_PATH) - log.info("Config file created successfully.") +# if not os.path.exists(LITELLM_CONFIG_PATH): +# log.info("Config file doesn't exist. Creating...") +# create_config_file(LITELLM_CONFIG_PATH) +# log.info("Config file created successfully.") #################################### @@ -826,18 +830,6 @@ AUDIO_OPENAI_API_VOICE = PersistentConfig( os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"), ) -#################################### -# LiteLLM -#################################### - - -ENABLE_LITELLM = os.environ.get("ENABLE_LITELLM", "True").lower() == "true" - -LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365")) -if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535: - raise ValueError("Invalid port number for LITELLM_PROXY_PORT") -LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1") - #################################### # Database diff --git a/backend/constants.py b/backend/constants.py index be4d135b2..86875d2df 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -32,6 +32,8 @@ class ERROR_MESSAGES(str, Enum): COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file." + MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string." + NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string." INVALID_TOKEN = ( "Your session has expired or the token is invalid. Please sign in again." diff --git a/backend/main.py b/backend/main.py index bbc0110ae..66fae2b60 100644 --- a/backend/main.py +++ b/backend/main.py @@ -19,27 +19,20 @@ from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import StreamingResponse, Response -from apps.ollama.main import app as ollama_app -from apps.openai.main import app as openai_app - -from apps.litellm.main import ( - app as litellm_app, - start_litellm_background, - shutdown_litellm_background, -) - +from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models +from apps.openai.main import app as openai_app, get_all_models as get_openai_models from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app -from apps.web.main import app as webui_app +from apps.webui.main import app as webui_app import asyncio from pydantic import BaseModel -from typing import List +from typing import List, Optional - -from utils.utils import get_admin_user +from apps.webui.models.models import Models, ModelModel +from utils.utils import get_admin_user, get_verified_user from apps.rag.utils import rag_messages from config import ( @@ -53,7 +46,8 @@ from config import ( FRONTEND_BUILD_DIR, CACHE_DIR, STATIC_DIR, - ENABLE_LITELLM, + ENABLE_OPENAI_API, + ENABLE_OLLAMA_API, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, GLOBAL_LOG_LEVEL, @@ -100,11 +94,7 @@ https://github.com/open-webui/open-webui @asynccontextmanager async def lifespan(app: FastAPI): - if ENABLE_LITELLM: - asyncio.create_task(start_litellm_background()) yield - if ENABLE_LITELLM: - await shutdown_litellm_background() app = FastAPI( @@ -112,11 +102,19 @@ app = FastAPI( ) app.state.config = AppConfig() + +app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API +app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API + app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST + app.state.config.WEBHOOK_URL = WEBHOOK_URL + +app.state.MODELS = {} + origins = ["*"] @@ -233,6 +231,11 @@ app.add_middleware( @app.middleware("http") async def check_url(request: Request, call_next): + if len(app.state.MODELS) == 0: + await get_all_models() + else: + pass + start_time = int(time.time()) response = await call_next(request) process_time = int(time.time()) - start_time @@ -249,9 +252,8 @@ async def update_embedding_function(request: Request, call_next): return response -app.mount("/litellm/api", litellm_app) app.mount("/ollama", ollama_app) -app.mount("/openai/api", openai_app) +app.mount("/openai", openai_app) app.mount("/images/api/v1", images_app) app.mount("/audio/api/v1", audio_app) @@ -262,6 +264,87 @@ app.mount("/api/v1", webui_app) webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION +async def get_all_models(): + openai_models = [] + ollama_models = [] + + if app.state.config.ENABLE_OPENAI_API: + openai_models = await get_openai_models() + + openai_models = openai_models["data"] + + if app.state.config.ENABLE_OLLAMA_API: + ollama_models = await get_ollama_models() + + ollama_models = [ + { + "id": model["model"], + "name": model["name"], + "object": "model", + "created": int(time.time()), + "owned_by": "ollama", + "ollama": model, + } + for model in ollama_models["models"] + ] + + models = openai_models + ollama_models + custom_models = Models.get_all_models() + + for custom_model in custom_models: + if custom_model.base_model_id == None: + for model in models: + if ( + custom_model.id == model["id"] + or custom_model.id == model["id"].split(":")[0] + ): + model["name"] = custom_model.name + model["info"] = custom_model.model_dump() + else: + owned_by = "openai" + for model in models: + if ( + custom_model.base_model_id == model["id"] + or custom_model.base_model_id == model["id"].split(":")[0] + ): + owned_by = model["owned_by"] + break + + models.append( + { + "id": custom_model.id, + "name": custom_model.name, + "object": "model", + "created": custom_model.created_at, + "owned_by": owned_by, + "info": custom_model.model_dump(), + "preset": True, + } + ) + + app.state.MODELS = {model["id"]: model for model in models} + + webui_app.state.MODELS = app.state.MODELS + + return models + + +@app.get("/api/models") +async def get_models(user=Depends(get_verified_user)): + models = await get_all_models() + if app.state.config.ENABLE_MODEL_FILTER: + if user.role == "user": + models = list( + filter( + lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, + models, + ) + ) + return {"data": models} + + return {"data": models} + + @app.get("/api/config") async def get_app_config(): # Checking and Handling the Absence of 'ui' in CONFIG_DATA @@ -276,12 +359,13 @@ async def get_app_config(): "name": WEBUI_NAME, "version": VERSION, "auth": WEBUI_AUTH, + "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), + "enable_signup": webui_app.state.config.ENABLE_SIGNUP, + "enable_image_generation": images_app.state.config.ENABLED, + "enable_admin_export": ENABLE_ADMIN_EXPORT, "default_locale": default_locale, - "images": images_app.state.config.ENABLED, "default_models": webui_app.state.config.DEFAULT_MODELS, "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, - "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), - "admin_export_enabled": ENABLE_ADMIN_EXPORT, } @@ -305,15 +389,6 @@ async def update_model_filter_config( app.state.config.ENABLE_MODEL_FILTER = form_data.enabled app.state.config.MODEL_FILTER_LIST = form_data.models - ollama_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER - ollama_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - - openai_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER - openai_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - - litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER - litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - return { "enabled": app.state.config.ENABLE_MODEL_FILTER, "models": app.state.config.MODEL_FILTER_LIST, @@ -334,7 +409,6 @@ class UrlForm(BaseModel): @app.post("/api/webhook") async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): app.state.config.WEBHOOK_URL = form_data.url - webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL return { diff --git a/backend/requirements.txt b/backend/requirements.txt index 29e37f8b8..7a3668428 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -18,8 +18,6 @@ psycopg2-binary==2.9.9 PyMySQL==1.1.1 bcrypt==4.1.3 -litellm[proxy]==1.37.20 - boto3==1.34.110 argon2-cffi==23.1.0 diff --git a/backend/space/litellm_config.yaml b/backend/space/litellm_config.yaml deleted file mode 100644 index af4f880b9..000000000 --- a/backend/space/litellm_config.yaml +++ /dev/null @@ -1,43 +0,0 @@ -litellm_settings: - drop_params: true -model_list: - - model_name: 'HuggingFace: Mistral: Mistral 7B Instruct v0.1' - litellm_params: - model: huggingface/mistralai/Mistral-7B-Instruct-v0.1 - api_key: os.environ/HF_TOKEN - max_tokens: 1024 - - model_name: 'HuggingFace: Mistral: Mistral 7B Instruct v0.2' - litellm_params: - model: huggingface/mistralai/Mistral-7B-Instruct-v0.2 - api_key: os.environ/HF_TOKEN - max_tokens: 1024 - - model_name: 'HuggingFace: Meta: Llama 3 8B Instruct' - litellm_params: - model: huggingface/meta-llama/Meta-Llama-3-8B-Instruct - api_key: os.environ/HF_TOKEN - max_tokens: 2047 - - model_name: 'HuggingFace: Mistral: Mixtral 8x7B Instruct v0.1' - litellm_params: - model: huggingface/mistralai/Mixtral-8x7B-Instruct-v0.1 - api_key: os.environ/HF_TOKEN - max_tokens: 8192 - - model_name: 'HuggingFace: Microsoft: Phi-3 Mini-4K-Instruct' - litellm_params: - model: huggingface/microsoft/Phi-3-mini-4k-instruct - api_key: os.environ/HF_TOKEN - max_tokens: 1024 - - model_name: 'HuggingFace: Google: Gemma 7B 1.1' - litellm_params: - model: huggingface/google/gemma-1.1-7b-it - api_key: os.environ/HF_TOKEN - max_tokens: 1024 - - model_name: 'HuggingFace: Yi-1.5 34B Chat' - litellm_params: - model: huggingface/01-ai/Yi-1.5-34B-Chat - api_key: os.environ/HF_TOKEN - max_tokens: 1024 - - model_name: 'HuggingFace: Nous Research: Nous Hermes 2 Mixtral 8x7B DPO' - litellm_params: - model: huggingface/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO - api_key: os.environ/HF_TOKEN - max_tokens: 2048 diff --git a/backend/start.sh b/backend/start.sh index ba7741e1d..15fc568d3 100755 --- a/backend/start.sh +++ b/backend/start.sh @@ -34,11 +34,6 @@ fi # Check if SPACE_ID is set, if so, configure for space if [ -n "$SPACE_ID" ]; then echo "Configuring for HuggingFace Space deployment" - - # Copy litellm_config.yaml with specified ownership - echo "Copying litellm_config.yaml to the desired location with specified ownership..." - cp -f ./space/litellm_config.yaml ./data/litellm/config.yaml - if [ -n "$ADMIN_USER_EMAIL" ] && [ -n "$ADMIN_USER_PASSWORD" ]; then echo "Admin user configured, creating" WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" uvicorn main:app --host "$HOST" --port "$PORT" --forwarded-allow-ips '*' & diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 5efff4a35..fca941263 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -1,5 +1,6 @@ from pathlib import Path import hashlib +import json import re from datetime import timedelta from typing import Optional @@ -110,3 +111,76 @@ def parse_duration(duration: str) -> Optional[timedelta]: total_duration += timedelta(weeks=number) return total_duration + + +def parse_ollama_modelfile(model_text): + parameters_meta = { + "mirostat": int, + "mirostat_eta": float, + "mirostat_tau": float, + "num_ctx": int, + "repeat_last_n": int, + "repeat_penalty": float, + "temperature": float, + "seed": int, + "stop": str, + "tfs_z": float, + "num_predict": int, + "top_k": int, + "top_p": float, + } + + data = {"base_model_id": None, "params": {}} + + # Parse base model + base_model_match = re.search( + r"^FROM\s+(\w+)", model_text, re.MULTILINE | re.IGNORECASE + ) + if base_model_match: + data["base_model_id"] = base_model_match.group(1) + + # Parse template + template_match = re.search( + r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE + ) + if template_match: + data["params"] = {"template": template_match.group(1).strip()} + + # Parse stops + stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE) + if stops: + data["params"]["stop"] = stops + + # Parse other parameters from the provided list + for param, param_type in parameters_meta.items(): + param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE) + if param_match: + value = param_match.group(1) + if param_type == int: + value = int(value) + elif param_type == float: + value = float(value) + data["params"][param] = value + + # Parse adapter + adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE) + if adapter_match: + data["params"]["adapter"] = adapter_match.group(1) + + # Parse system description + system_desc_match = re.search( + r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE + ) + if system_desc_match: + data["params"]["system"] = system_desc_match.group(1).strip() + + # Parse messages + messages = [] + message_matches = re.findall(r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE) + for role, content in message_matches: + messages.append({"role": role, "content": content}) + + if messages: + data["params"]["messages"] = messages + + return data diff --git a/backend/utils/models.py b/backend/utils/models.py new file mode 100644 index 000000000..c4d675d29 --- /dev/null +++ b/backend/utils/models.py @@ -0,0 +1,10 @@ +from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse + + +def get_model_id_from_custom_model_id(id: str): + model = Models.get_model_by_id(id) + + if model: + return model.id + else: + return id diff --git a/backend/utils/utils.py b/backend/utils/utils.py index af4fd85c0..cc6bb06b8 100644 --- a/backend/utils/utils.py +++ b/backend/utils/utils.py @@ -1,7 +1,7 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi import HTTPException, status, Depends -from apps.web.models.users import Users +from apps.webui.models.users import Users from pydantic import BaseModel from typing import Union, Optional diff --git a/requirements-dev.lock b/requirements-dev.lock index 93c126eb4..39b1d0ef0 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -273,7 +273,6 @@ langsmith==0.1.57 # via langchain-community # via langchain-core litellm==1.37.20 - # via litellm # via open-webui lxml==5.2.2 # via unstructured @@ -396,7 +395,6 @@ pandas==2.2.2 # via open-webui passlib==1.7.4 # via open-webui - # via passlib pathspec==0.12.1 # via black peewee==3.17.5 @@ -454,7 +452,6 @@ pygments==2.18.0 pyjwt==2.8.0 # via litellm # via open-webui - # via pyjwt pymysql==1.1.0 # via open-webui pypandoc==1.13 @@ -559,6 +556,9 @@ scipy==1.13.0 # via sentence-transformers sentence-transformers==2.7.0 # via open-webui +setuptools==69.5.1 + # via ctranslate2 + # via opentelemetry-instrumentation shapely==2.0.4 # via rapidocr-onnxruntime shellingham==1.5.4 @@ -659,7 +659,6 @@ uvicorn==0.22.0 # via fastapi # via litellm # via open-webui - # via uvicorn uvloop==0.19.0 # via uvicorn validators==0.28.1 @@ -687,6 +686,3 @@ youtube-transcript-api==0.6.2 # via open-webui zipp==3.18.1 # via importlib-metadata -setuptools==69.5.1 - # via ctranslate2 - # via opentelemetry-instrumentation diff --git a/requirements.lock b/requirements.lock index 93c126eb4..39b1d0ef0 100644 --- a/requirements.lock +++ b/requirements.lock @@ -273,7 +273,6 @@ langsmith==0.1.57 # via langchain-community # via langchain-core litellm==1.37.20 - # via litellm # via open-webui lxml==5.2.2 # via unstructured @@ -396,7 +395,6 @@ pandas==2.2.2 # via open-webui passlib==1.7.4 # via open-webui - # via passlib pathspec==0.12.1 # via black peewee==3.17.5 @@ -454,7 +452,6 @@ pygments==2.18.0 pyjwt==2.8.0 # via litellm # via open-webui - # via pyjwt pymysql==1.1.0 # via open-webui pypandoc==1.13 @@ -559,6 +556,9 @@ scipy==1.13.0 # via sentence-transformers sentence-transformers==2.7.0 # via open-webui +setuptools==69.5.1 + # via ctranslate2 + # via opentelemetry-instrumentation shapely==2.0.4 # via rapidocr-onnxruntime shellingham==1.5.4 @@ -659,7 +659,6 @@ uvicorn==0.22.0 # via fastapi # via litellm # via open-webui - # via uvicorn uvloop==0.19.0 # via uvicorn validators==0.28.1 @@ -687,6 +686,3 @@ youtube-transcript-api==0.6.2 # via open-webui zipp==3.18.1 # via importlib-metadata -setuptools==69.5.1 - # via ctranslate2 - # via opentelemetry-instrumentation diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index a72b51939..834e29d29 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -654,3 +654,35 @@ export const deleteAllChats = async (token: string) => { return res; }; + +export const archiveAllChats = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/archive/all`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index a610f7210..5d94e7678 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -1,5 +1,54 @@ import { WEBUI_BASE_URL } from '$lib/constants'; +export const getModels = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/models`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + let models = res?.data ?? []; + + models = models + .filter((models) => models) + .sort((a, b) => { + // Compare case-insensitively + const lowerA = a.name.toLowerCase(); + const lowerB = b.name.toLowerCase(); + + if (lowerA < lowerB) return -1; + if (lowerA > lowerB) return 1; + + // If same case-insensitively, sort by original strings, + // lowercase will come before uppercase due to ASCII values + if (a < b) return -1; + if (a > b) return 1; + + return 0; // They are equal + }); + + console.log(models); + return models; +}; + export const getBackendConfig = async () => { let error = null; @@ -196,3 +245,77 @@ export const updateWebhookUrl = async (token: string, url: string) => { return res.url; }; + +export const getModelConfig = async (token: string): Promise => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/models`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res.models; +}; + +export interface ModelConfig { + id: string; + name: string; + meta: ModelMeta; + base_model_id?: string; + params: ModelParams; +} + +export interface ModelMeta { + description?: string; + capabilities?: object; +} + +export interface ModelParams {} + +export type GlobalModelConfig = ModelConfig[]; + +export const updateModelConfig = async (token: string, config: GlobalModelConfig) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/models`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + models: config + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/litellm/index.ts b/src/lib/apis/litellm/index.ts deleted file mode 100644 index 643146b73..000000000 --- a/src/lib/apis/litellm/index.ts +++ /dev/null @@ -1,150 +0,0 @@ -import { LITELLM_API_BASE_URL } from '$lib/constants'; - -export const getLiteLLMModels = async (token: string = '') => { - let error = null; - - const res = await fetch(`${LITELLM_API_BASE_URL}/v1/models`, { - method: 'GET', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - error = `LiteLLM: ${err?.error?.message ?? 'Network Problem'}`; - return []; - }); - - if (error) { - throw error; - } - - const models = Array.isArray(res) ? res : res?.data ?? null; - - return models - ? models - .map((model) => ({ - id: model.id, - name: model.name ?? model.id, - external: true, - source: 'LiteLLM' - })) - .sort((a, b) => { - return a.name.localeCompare(b.name); - }) - : models; -}; - -export const getLiteLLMModelInfo = async (token: string = '') => { - let error = null; - - const res = await fetch(`${LITELLM_API_BASE_URL}/model/info`, { - method: 'GET', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - error = `LiteLLM: ${err?.error?.message ?? 'Network Problem'}`; - return []; - }); - - if (error) { - throw error; - } - - const models = Array.isArray(res) ? res : res?.data ?? null; - - return models; -}; - -type AddLiteLLMModelForm = { - name: string; - model: string; - api_base: string; - api_key: string; - rpm: string; - max_tokens: string; -}; - -export const addLiteLLMModel = async (token: string = '', payload: AddLiteLLMModelForm) => { - let error = null; - - const res = await fetch(`${LITELLM_API_BASE_URL}/model/new`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - }, - body: JSON.stringify({ - model_name: payload.name, - litellm_params: { - model: payload.model, - ...(payload.api_base === '' ? {} : { api_base: payload.api_base }), - ...(payload.api_key === '' ? {} : { api_key: payload.api_key }), - ...(isNaN(parseInt(payload.rpm)) ? {} : { rpm: parseInt(payload.rpm) }), - ...(payload.max_tokens === '' ? {} : { max_tokens: payload.max_tokens }) - } - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - error = `LiteLLM: ${err?.error?.message ?? 'Network Problem'}`; - return []; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const deleteLiteLLMModel = async (token: string = '', id: string) => { - let error = null; - - const res = await fetch(`${LITELLM_API_BASE_URL}/model/delete`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - }, - body: JSON.stringify({ - id: id - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - error = `LiteLLM: ${err?.error?.message ?? 'Network Problem'}`; - return []; - }); - - if (error) { - throw error; - } - - return res; -}; diff --git a/src/lib/apis/modelfiles/index.ts b/src/lib/apis/models/index.ts similarity index 65% rename from src/lib/apis/modelfiles/index.ts rename to src/lib/apis/models/index.ts index 91af5e381..9faa358d3 100644 --- a/src/lib/apis/modelfiles/index.ts +++ b/src/lib/apis/models/index.ts @@ -1,18 +1,16 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const createNewModelfile = async (token: string, modelfile: object) => { +export const addNewModel = async (token: string, model: object) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/create`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/add`, { method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', authorization: `Bearer ${token}` }, - body: JSON.stringify({ - modelfile: modelfile - }) + body: JSON.stringify(model) }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -31,10 +29,10 @@ export const createNewModelfile = async (token: string, modelfile: object) => { return res; }; -export const getModelfiles = async (token: string = '') => { +export const getModelInfos = async (token: string = '') => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models`, { method: 'GET', headers: { Accept: 'application/json', @@ -59,62 +57,22 @@ export const getModelfiles = async (token: string = '') => { throw error; } - return res.map((modelfile) => modelfile.modelfile); + return res; }; -export const getModelfileByTagName = async (token: string, tagName: string) => { +export const getModelById = async (token: string, id: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, { - method: 'POST', + const searchParams = new URLSearchParams(); + searchParams.append('id', id); + + const res = await fetch(`${WEBUI_API_BASE_URL}/models?${searchParams.toString()}`, { + method: 'GET', headers: { Accept: 'application/json', 'Content-Type': 'application/json', authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - tag_name: tagName - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .then((json) => { - return json; - }) - .catch((err) => { - error = err; - - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res.modelfile; -}; - -export const updateModelfileByTagName = async ( - token: string, - tagName: string, - modelfile: object -) => { - let error = null; - - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/update`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - tag_name: tagName, - modelfile: modelfile - }) + } }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -137,19 +95,55 @@ export const updateModelfileByTagName = async ( return res; }; -export const deleteModelfileByTagName = async (token: string, tagName: string) => { +export const updateModelById = async (token: string, id: string, model: object) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/delete`, { + const searchParams = new URLSearchParams(); + searchParams.append('id', id); + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/update?${searchParams.toString()}`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify(model) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteModelById = async (token: string, id: string) => { + let error = null; + + const searchParams = new URLSearchParams(); + searchParams.append('id', id); + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/delete?${searchParams.toString()}`, { method: 'DELETE', headers: { Accept: 'application/json', 'Content-Type': 'application/json', authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - tag_name: tagName - }) + } }) .then(async (res) => { if (!res.ok) throw await res.json(); diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index b7f842177..efc3f0d0f 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -164,7 +164,7 @@ export const getOllamaVersion = async (token: string = '') => { throw error; } - return res?.version ?? ''; + return res?.version ?? false; }; export const getOllamaModels = async (token: string = '') => { diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 02281eff0..8afcec018 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -230,7 +230,12 @@ export const getOpenAIModels = async (token: string = '') => { return models ? models - .map((model) => ({ id: model.id, name: model.name ?? model.id, external: true })) + .map((model) => ({ + id: model.id, + name: model.name ?? model.id, + external: true, + custom_info: model.custom_info + })) .sort((a, b) => { return a.name.localeCompare(b.name); }) diff --git a/src/lib/components/admin/Settings/Database.svelte b/src/lib/components/admin/Settings/Database.svelte index cde6bcaa4..711c1254f 100644 --- a/src/lib/components/admin/Settings/Database.svelte +++ b/src/lib/components/admin/Settings/Database.svelte @@ -1,13 +1,24 @@ -{#if code} -
-
-
{@html lang}
+
+
+
{@html lang}
-
- {#if lang === 'python' || (lang === '' && checkPythonCode(code))} - {#if executing} -
Running
- {:else} - - {/if} +
+ {#if lang === 'python' || (lang === '' && checkPythonCode(code))} + {#if executing} +
Running
+ {:else} + {/if} - -
+ {/if} +
- -
{@html highlightedCode || code}
- -
- - {#if executing} -
-
STDOUT/STDERR
-
Running...
-
- {:else if stdout || stderr || result} -
-
STDOUT/STDERR
-
{stdout || stderr || result}
-
- {/if}
-{/if} + +
{@html highlightedCode || code}
+ +
+ + {#if executing} +
+
STDOUT/STDERR
+
Running...
+
+ {:else if stdout || stderr || result} +
+
STDOUT/STDERR
+
{stdout || stderr || result}
+
+ {/if} +
diff --git a/src/lib/components/chat/Messages/CompareMessages.svelte b/src/lib/components/chat/Messages/CompareMessages.svelte index 60efdb2ab..f904a57ab 100644 --- a/src/lib/components/chat/Messages/CompareMessages.svelte +++ b/src/lib/components/chat/Messages/CompareMessages.svelte @@ -13,8 +13,6 @@ export let parentMessage; - export let selectedModelfiles; - export let updateChatMessages: Function; export let confirmEditResponseMessage: Function; export let rateMessage: Function; @@ -130,7 +128,6 @@ > m.id)} isLastMessage={true} {updateChatMessages} diff --git a/src/lib/components/chat/Messages/Placeholder.svelte b/src/lib/components/chat/Messages/Placeholder.svelte index dfb6cfb36..ed121dbe6 100644 --- a/src/lib/components/chat/Messages/Placeholder.svelte +++ b/src/lib/components/chat/Messages/Placeholder.svelte @@ -1,6 +1,6 @@ - -
-
-
{$i18n.t('Parameters')}
- - -
- -
-
-
{$i18n.t('Keep Alive')}
- - -
- - {#if keepAlive !== null} -
- -
- {/if} -
- -
-
-
{$i18n.t('Request Mode')}
- - -
-
-
- -
- -
-
diff --git a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte index 6eaf82da8..93c482711 100644 --- a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte +++ b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte @@ -1,14 +1,16 @@ -
-
-
-
{$i18n.t('Seed')}
-
- -
+
+
+
+
{$i18n.t('Seed')}
+ +
+ + {#if (params?.seed ?? null) !== null} +
+
+ +
+
+ {/if}
-
-
-
{$i18n.t('Stop Sequence')}
-
- -
+
+
+
{$i18n.t('Stop Sequence')}
+ +
+ + {#if (params?.stop ?? null) !== null} +
+
+ +
+
+ {/if}
@@ -61,10 +109,10 @@ class="p-1 px-3 text-xs flex rounded transition" type="button" on:click={() => { - options.temperature = options.temperature === '' ? 0.8 : ''; + params.temperature = (params?.temperature ?? '') === '' ? 0.8 : ''; }} > - {#if options.temperature === ''} + {#if (params?.temperature ?? '') === ''} {$i18n.t('Default')} {:else} {$i18n.t('Custom')} @@ -72,7 +120,7 @@
- {#if options.temperature !== ''} + {#if (params?.temperature ?? '') !== ''}
{ - options.mirostat = options.mirostat === '' ? 0 : ''; + params.mirostat = (params?.mirostat ?? '') === '' ? 0 : ''; }} > - {#if options.mirostat === ''} + {#if (params?.mirostat ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.mirostat !== ''} + {#if (params?.mirostat ?? '') !== ''}
{ - options.mirostat_eta = options.mirostat_eta === '' ? 0.1 : ''; + params.mirostat_eta = (params?.mirostat_eta ?? '') === '' ? 0.1 : ''; }} > - {#if options.mirostat_eta === ''} + {#if (params?.mirostat_eta ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.mirostat_eta !== ''} + {#if (params?.mirostat_eta ?? '') !== ''}
{ - options.mirostat_tau = options.mirostat_tau === '' ? 5.0 : ''; + params.mirostat_tau = (params?.mirostat_tau ?? '') === '' ? 5.0 : ''; }} > - {#if options.mirostat_tau === ''} + {#if (params?.mirostat_tau ?? '') === ''} {$i18n.t('Default')} {:else} {$i18n.t('Custom')} @@ -210,7 +258,7 @@
- {#if options.mirostat_tau !== ''} + {#if (params?.mirostat_tau ?? '') !== ''}
{ - options.top_k = options.top_k === '' ? 40 : ''; + params.top_k = (params?.top_k ?? '') === '' ? 40 : ''; }} > - {#if options.top_k === ''} + {#if (params?.top_k ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.top_k !== ''} + {#if (params?.top_k ?? '') !== ''}
{ - options.top_p = options.top_p === '' ? 0.9 : ''; + params.top_p = (params?.top_p ?? '') === '' ? 0.9 : ''; }} > - {#if options.top_p === ''} + {#if (params?.top_p ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.top_p !== ''} + {#if (params?.top_p ?? '') !== ''}
-
{$i18n.t('Repeat Penalty')}
+
{$i18n.t('Frequencey Penalty')}
- {#if options.repeat_penalty !== ''} + {#if (params?.frequency_penalty ?? '') !== ''}
{ - options.repeat_last_n = options.repeat_last_n === '' ? 64 : ''; + params.repeat_last_n = (params?.repeat_last_n ?? '') === '' ? 64 : ''; }} > - {#if options.repeat_last_n === ''} + {#if (params?.repeat_last_n ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.repeat_last_n !== ''} + {#if (params?.repeat_last_n ?? '') !== ''}
{ - options.tfs_z = options.tfs_z === '' ? 1 : ''; + params.tfs_z = (params?.tfs_z ?? '') === '' ? 1 : ''; }} > - {#if options.tfs_z === ''} + {#if (params?.tfs_z ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.tfs_z !== ''} + {#if (params?.tfs_z ?? '') !== ''}
{ - options.num_ctx = options.num_ctx === '' ? 2048 : ''; + params.num_ctx = (params?.num_ctx ?? '') === '' ? 2048 : ''; }} > - {#if options.num_ctx === ''} + {#if (params?.num_ctx ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.num_ctx !== ''} + {#if (params?.num_ctx ?? '') !== ''}
-
{$i18n.t('Max Tokens')}
+
{$i18n.t('Max Tokens (num_predict)')}
- {#if options.num_predict !== ''} + {#if (params?.max_tokens ?? '') !== ''}
{/if}
+
+
+
{$i18n.t('Template')}
+ + +
+ + {#if (params?.template ?? null) !== null} +
+
+