diff --git a/.env.example b/.env.example index 2d782fce1..c38bf88bf 100644 --- a/.env.example +++ b/.env.example @@ -10,8 +10,4 @@ OPENAI_API_KEY='' # DO NOT TRACK SCARF_NO_ANALYTICS=true DO_NOT_TRACK=true -ANONYMIZED_TELEMETRY=false - -# Use locally bundled version of the LiteLLM cost map json -# to avoid repetitive startup connections -LITELLM_LOCAL_MODEL_COST_MAP="True" \ No newline at end of file +ANONYMIZED_TELEMETRY=false \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.disabled similarity index 100% rename from .github/dependabot.yml rename to .github/dependabot.disabled diff --git a/.github/workflows/build-release.yml b/.github/workflows/build-release.yml index 036bb97ae..cae363f42 100644 --- a/.github/workflows/build-release.yml +++ b/.github/workflows/build-release.yml @@ -11,7 +11,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Check for changes in package.json run: | @@ -36,7 +36,7 @@ jobs: echo "::set-output name=content::$CHANGELOG_ESCAPED" - name: Create GitHub release - uses: actions/github-script@v5 + uses: actions/github-script@v7 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | @@ -51,7 +51,7 @@ jobs: console.log(`Created release ${release.data.html_url}`) - name: Upload package to GitHub release - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: package path: . 
diff --git a/.github/workflows/deploy-to-hf-spaces.yml b/.github/workflows/deploy-to-hf-spaces.yml new file mode 100644 index 000000000..aa8bbcfce --- /dev/null +++ b/.github/workflows/deploy-to-hf-spaces.yml @@ -0,0 +1,59 @@ +name: Deploy to HuggingFace Spaces + +on: + push: + branches: + - dev + - main + workflow_dispatch: + +jobs: + check-secret: + runs-on: ubuntu-latest + outputs: + token-set: ${{ steps.check-key.outputs.defined }} + steps: + - id: check-key + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + if: "${{ env.HF_TOKEN != '' }}" + run: echo "defined=true" >> $GITHUB_OUTPUT + + deploy: + runs-on: ubuntu-latest + needs: [check-secret] + if: needs.check-secret.outputs.token-set == 'true' + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Remove git history + run: rm -rf .git + + - name: Prepend YAML front matter to README.md + run: | + echo "---" > temp_readme.md + echo "title: Open WebUI" >> temp_readme.md + echo "emoji: 🐳" >> temp_readme.md + echo "colorFrom: purple" >> temp_readme.md + echo "colorTo: gray" >> temp_readme.md + echo "sdk: docker" >> temp_readme.md + echo "app_port: 8080" >> temp_readme.md + echo "---" >> temp_readme.md + cat README.md >> temp_readme.md + mv temp_readme.md README.md + + - name: Configure git + run: | + git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com" + git config --global user.name "github-actions[bot]" + - name: Set up Git and push to Space + run: | + git init --initial-branch=main + git lfs track "*.ttf" + rm demo.gif + git add . 
+ git commit -m "GitHub deploy: ${{ github.sha }}" + git push --force https://open-webui:${HF_TOKEN}@huggingface.co/spaces/open-webui/open-webui main diff --git a/.github/workflows/docker-build.yaml b/.github/workflows/docker-build.yaml index b5dd72192..86d27f4dc 100644 --- a/.github/workflows/docker-build.yaml +++ b/.github/workflows/docker-build.yaml @@ -84,6 +84,8 @@ jobs: outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }} cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max + build-args: | + BUILD_HASH=${{ github.sha }} - name: Export digest run: | @@ -170,7 +172,9 @@ jobs: outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }} cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max - build-args: USE_CUDA=true + build-args: | + BUILD_HASH=${{ github.sha }} + USE_CUDA=true - name: Export digest run: | @@ -257,7 +261,9 @@ jobs: outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }} cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max - build-args: USE_OLLAMA=true + build-args: | + BUILD_HASH=${{ github.sha }} + USE_OLLAMA=true - name: Export digest run: | diff --git a/.github/workflows/format-backend.yaml b/.github/workflows/format-backend.yaml index dd0e94868..2e980de41 100644 --- a/.github/workflows/format-backend.yaml +++ b/.github/workflows/format-backend.yaml @@ -23,7 +23,7 @@ jobs: - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/format-build-frontend.yaml b/.github/workflows/format-build-frontend.yaml index 
6f89f14a9..9ee57f475 100644 --- a/.github/workflows/format-build-frontend.yaml +++ b/.github/workflows/format-build-frontend.yaml @@ -19,7 +19,7 @@ jobs: uses: actions/checkout@v4 - name: Setup Node.js - uses: actions/setup-node@v3 + uses: actions/setup-node@v4 with: node-version: '20' # Or specify any other version you want to use diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 32346d3b9..2426aff27 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -20,7 +20,11 @@ jobs: - name: Build and run Compose Stack run: | - docker compose --file docker-compose.yaml --file docker-compose.api.yaml up --detach --build + docker compose \ + --file docker-compose.yaml \ + --file docker-compose.api.yaml \ + --file docker-compose.a1111-test.yaml \ + up --detach --build - name: Wait for Ollama to be up timeout-minutes: 5 @@ -95,7 +99,7 @@ jobs: uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml new file mode 100644 index 000000000..b786329c2 --- /dev/null +++ b/.github/workflows/release-pypi.yml @@ -0,0 +1,32 @@ +name: Release to PyPI + +on: + push: + branches: + - main # or whatever branch you want to use + - dev + +jobs: + release: + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/open-webui + permissions: + id-token: write + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - uses: actions/setup-node@v4 + with: + node-version: 18 + - uses: actions/setup-python@v5 + with: + python-version: 3.11 + - name: Build + run: | + python -m pip install --upgrade pip + pip install build + python -m build . 
+ - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/Dockerfile b/Dockerfile index dee049fb4..be5c1da41 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,12 +11,14 @@ ARG USE_CUDA_VER=cu121 # IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 ARG USE_RERANKING_MODEL="" +ARG BUILD_HASH=dev-build # Override at your own risk - non-root configurations are untested ARG UID=0 ARG GID=0 ######## WebUI frontend ######## FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build +ARG BUILD_HASH WORKDIR /app @@ -24,6 +26,7 @@ COPY package.json package-lock.json ./ RUN npm ci COPY . . +ENV APP_BUILD_HASH=${BUILD_HASH} RUN npm run build ######## WebUI backend ######## @@ -35,6 +38,7 @@ ARG USE_OLLAMA ARG USE_CUDA_VER ARG USE_EMBEDDING_MODEL ARG USE_RERANKING_MODEL +ARG BUILD_HASH ARG UID ARG GID @@ -59,11 +63,6 @@ ENV OPENAI_API_KEY="" \ DO_NOT_TRACK=true \ ANONYMIZED_TELEMETRY=false -# Use locally bundled version of the LiteLLM cost map json -# to avoid repetitive startup connections -ENV LITELLM_LOCAL_MODEL_COST_MAP="True" - - #### Other models ######################################################### ## whisper TTS model settings ## ENV WHISPER_MODEL="base" \ @@ -83,10 +82,10 @@ WORKDIR /app/backend ENV HOME /root # Create user and group if not root RUN if [ $UID -ne 0 ]; then \ - if [ $GID -ne 0 ]; then \ - addgroup --gid $GID app; \ - fi; \ - adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \ + if [ $GID -ne 0 ]; then \ + addgroup --gid $GID app; \ + fi; \ + adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \ fi RUN mkdir -p $HOME/.cache/chroma @@ -132,7 +131,8 @@ RUN pip3 install uv && \ uv pip install --system -r requirements.txt 
--no-cache-dir && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ - fi + fi; \ + chown -R $UID:$GID /app/backend/data/ @@ -154,4 +154,6 @@ HEALTHCHECK CMD curl --silent --fail http://localhost:8080/health | jq -e '.stat USER $UID:$GID +ENV WEBUI_BUILD_VERSION=${BUILD_HASH} + CMD [ "bash", "start.sh"] diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py deleted file mode 100644 index 6a355038b..000000000 --- a/backend/apps/litellm/main.py +++ /dev/null @@ -1,379 +0,0 @@ -import sys -from contextlib import asynccontextmanager - -from fastapi import FastAPI, Depends, HTTPException -from fastapi.routing import APIRoute -from fastapi.middleware.cors import CORSMiddleware - -import logging -from fastapi import FastAPI, Request, Depends, status, Response -from fastapi.responses import JSONResponse - -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.responses import StreamingResponse -import json -import time -import requests - -from pydantic import BaseModel, ConfigDict -from typing import Optional, List - -from utils.utils import get_verified_user, get_current_user, get_admin_user -from config import SRC_LOG_LEVELS, ENV -from constants import MESSAGES - -import os - -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["LITELLM"]) - - -from config import ( - ENABLE_LITELLM, - ENABLE_MODEL_FILTER, - MODEL_FILTER_LIST, - DATA_DIR, - LITELLM_PROXY_PORT, - LITELLM_PROXY_HOST, -) - -import warnings - -warnings.simplefilter("ignore") - -from litellm.utils import get_llm_provider - -import asyncio -import subprocess -import yaml - - -@asynccontextmanager -async def lifespan(app: FastAPI): - 
log.info("startup_event") - # TODO: Check config.yaml file and create one - asyncio.create_task(start_litellm_background()) - yield - - -app = FastAPI(lifespan=lifespan) - -origins = ["*"] - -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml" - -with open(LITELLM_CONFIG_DIR, "r") as file: - litellm_config = yaml.safe_load(file) - - -app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value -app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value - - -app.state.ENABLE = ENABLE_LITELLM -app.state.CONFIG = litellm_config - -# Global variable to store the subprocess reference -background_process = None - -CONFLICT_ENV_VARS = [ - # Uvicorn uses PORT, so LiteLLM might use it as well - "PORT", - # LiteLLM uses DATABASE_URL for Prisma connections - "DATABASE_URL", -] - - -async def run_background_process(command): - global background_process - log.info("run_background_process") - - try: - # Log the command to be executed - log.info(f"Executing command: {command}") - # Filter environment variables known to conflict with litellm - env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS} - # Execute the command and create a subprocess - process = await asyncio.create_subprocess_exec( - *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env - ) - background_process = process - log.info("Subprocess started successfully.") - - # Capture STDERR for debugging purposes - stderr_output = await process.stderr.read() - stderr_text = stderr_output.decode().strip() - if stderr_text: - log.info(f"Subprocess STDERR: {stderr_text}") - - # log.info output line by line - async for line in process.stdout: - log.info(line.decode().strip()) - - # Wait for the process to finish - returncode = await process.wait() - log.info(f"Subprocess exited with return code {returncode}") - except Exception as e: - log.error(f"Failed to 
start subprocess: {e}") - raise # Optionally re-raise the exception if you want it to propagate - - -async def start_litellm_background(): - log.info("start_litellm_background") - # Command to run in the background - command = [ - "litellm", - "--port", - str(LITELLM_PROXY_PORT), - "--host", - LITELLM_PROXY_HOST, - "--telemetry", - "False", - "--config", - LITELLM_CONFIG_DIR, - ] - - await run_background_process(command) - - -async def shutdown_litellm_background(): - log.info("shutdown_litellm_background") - global background_process - if background_process: - background_process.terminate() - await background_process.wait() # Ensure the process has terminated - log.info("Subprocess terminated") - background_process = None - - -@app.get("/") -async def get_status(): - return {"status": True} - - -async def restart_litellm(): - """ - Endpoint to restart the litellm background service. - """ - log.info("Requested restart of litellm service.") - try: - # Shut down the existing process if it is running - await shutdown_litellm_background() - log.info("litellm service shutdown complete.") - - # Restart the background service - - asyncio.create_task(start_litellm_background()) - log.info("litellm service restart complete.") - - return { - "status": "success", - "message": "litellm service restarted successfully.", - } - except Exception as e: - log.info(f"Error restarting litellm service: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) - ) - - -@app.get("/restart") -async def restart_litellm_handler(user=Depends(get_admin_user)): - return await restart_litellm() - - -@app.get("/config") -async def get_config(user=Depends(get_admin_user)): - return app.state.CONFIG - - -class LiteLLMConfigForm(BaseModel): - general_settings: Optional[dict] = None - litellm_settings: Optional[dict] = None - model_list: Optional[List[dict]] = None - router_settings: Optional[dict] = None - - model_config = ConfigDict(protected_namespaces=()) - 
- -@app.post("/config/update") -async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)): - app.state.CONFIG = form_data.model_dump(exclude_none=True) - - with open(LITELLM_CONFIG_DIR, "w") as file: - yaml.dump(app.state.CONFIG, file) - - await restart_litellm() - return app.state.CONFIG - - -@app.get("/models") -@app.get("/v1/models") -async def get_models(user=Depends(get_current_user)): - - if app.state.ENABLE: - while not background_process: - await asyncio.sleep(0.1) - - url = f"http://localhost:{LITELLM_PROXY_PORT}/v1" - r = None - try: - r = requests.request(method="GET", url=f"{url}/models") - r.raise_for_status() - - data = r.json() - - if app.state.ENABLE_MODEL_FILTER: - if user and user.role == "user": - data["data"] = list( - filter( - lambda model: model["id"] in app.state.MODEL_FILTER_LIST, - data["data"], - ) - ) - - return data - except Exception as e: - - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"External: {res['error']}" - except: - error_detail = f"External: {e}" - - return { - "data": [ - { - "id": model["model_name"], - "object": "model", - "created": int(time.time()), - "owned_by": "openai", - } - for model in app.state.CONFIG["model_list"] - ], - "object": "list", - } - else: - return { - "data": [], - "object": "list", - } - - -@app.get("/model/info") -async def get_model_list(user=Depends(get_admin_user)): - return {"data": app.state.CONFIG["model_list"]} - - -class AddLiteLLMModelForm(BaseModel): - model_name: str - litellm_params: dict - - model_config = ConfigDict(protected_namespaces=()) - - -@app.post("/model/new") -async def add_model_to_config( - form_data: AddLiteLLMModelForm, user=Depends(get_admin_user) -): - try: - get_llm_provider(model=form_data.model_name) - app.state.CONFIG["model_list"].append(form_data.model_dump()) - - with open(LITELLM_CONFIG_DIR, "w") as file: - 
yaml.dump(app.state.CONFIG, file) - - await restart_litellm() - - return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)} - except Exception as e: - print(e) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) - ) - - -class DeleteLiteLLMModelForm(BaseModel): - id: str - - -@app.post("/model/delete") -async def delete_model_from_config( - form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user) -): - app.state.CONFIG["model_list"] = [ - model - for model in app.state.CONFIG["model_list"] - if model["model_name"] != form_data.id - ] - - with open(LITELLM_CONFIG_DIR, "w") as file: - yaml.dump(app.state.CONFIG, file) - - await restart_litellm() - - return {"message": MESSAGES.MODEL_DELETED(form_data.id)} - - -@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(path: str, request: Request, user=Depends(get_verified_user)): - body = await request.body() - - url = f"http://localhost:{LITELLM_PROXY_PORT}" - - target_url = f"{url}/{path}" - - headers = {} - # headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - - r = None - - try: - r = requests.request( - method=request.method, - url=target_url, - data=body, - headers=headers, - stream=True, - ) - - r.raise_for_status() - - # Check if response is SSE - if "text/event-stream" in r.headers.get("Content-Type", ""): - return StreamingResponse( - r.iter_content(chunk_size=8192), - status_code=r.status_code, - headers=dict(r.headers), - ) - else: - response_data = r.json() - return response_data - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" - except: - error_detail = f"External: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, detail=error_detail - ) diff --git 
a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index df268067f..01e127074 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -29,8 +29,8 @@ import time from urllib.parse import urlparse from typing import Optional, List, Union - -from apps.web.models.users import Users +from apps.webui.models.models import Models +from apps.webui.models.users import Users from constants import ERROR_MESSAGES from utils.utils import ( decode_token, @@ -39,10 +39,13 @@ from utils.utils import ( get_admin_user, ) +from utils.models import get_model_id_from_custom_model_id + from config import ( SRC_LOG_LEVELS, OLLAMA_BASE_URLS, + ENABLE_OLLAMA_API, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, UPLOAD_DIR, @@ -67,6 +70,7 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} @@ -96,6 +100,21 @@ async def get_status(): return {"status": True} +@app.get("/config") +async def get_config(user=Depends(get_admin_user)): + return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API} + + +class OllamaConfigForm(BaseModel): + enable_ollama_api: Optional[bool] = None + + +@app.post("/config/update") +async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)): + app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api + return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API} + + @app.get("/urls") async def get_ollama_api_urls(user=Depends(get_admin_user)): return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} @@ -156,14 +175,23 @@ def merge_models_lists(model_lists): async def get_all_models(): log.info("get_all_models()") - tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS] - responses = await asyncio.gather(*tasks) - models = { - "models": 
merge_models_lists( - map(lambda response: response["models"] if response else None, responses) - ) - } + if app.state.config.ENABLE_OLLAMA_API: + tasks = [ + fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS + ] + responses = await asyncio.gather(*tasks) + + models = { + "models": merge_models_lists( + map( + lambda response: response["models"] if response else None, responses + ) + ) + } + + else: + models = {"models": []} app.state.MODELS = {model["model"]: model for model in models["models"]} @@ -278,6 +306,9 @@ async def pull_model( r = None + # Admin should be able to pull models from any source + payload = {**form_data.model_dump(exclude_none=True), "insecure": True} + def get_request(): nonlocal url nonlocal r @@ -305,7 +336,7 @@ async def pull_model( r = requests.request( method="POST", url=f"{url}/api/pull", - data=form_data.model_dump_json(exclude_none=True).encode(), + data=json.dumps(payload), stream=True, ) @@ -848,14 +879,93 @@ async def generate_chat_completion( user=Depends(get_verified_user), ): + log.debug( + "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( + form_data.model_dump_json(exclude_none=True).encode() + ) + ) + + payload = { + **form_data.model_dump(exclude_none=True), + } + + model_id = form_data.model + model_info = Models.get_model_by_id(model_id) + + if model_info: + print(model_info) + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + model_info.params = model_info.params.model_dump() + + if model_info.params: + payload["options"] = {} + + payload["options"]["mirostat"] = model_info.params.get("mirostat", None) + payload["options"]["mirostat_eta"] = model_info.params.get( + "mirostat_eta", None + ) + payload["options"]["mirostat_tau"] = model_info.params.get( + "mirostat_tau", None + ) + payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None) + + payload["options"]["repeat_last_n"] = model_info.params.get( + "repeat_last_n", None + ) + 
payload["options"]["repeat_penalty"] = model_info.params.get( + "frequency_penalty", None + ) + + payload["options"]["temperature"] = model_info.params.get( + "temperature", None + ) + payload["options"]["seed"] = model_info.params.get("seed", None) + + payload["options"]["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + else None + ) + + payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None) + + payload["options"]["num_predict"] = model_info.params.get( + "max_tokens", None + ) + payload["options"]["top_k"] = model_info.params.get("top_k", None) + + payload["options"]["top_p"] = model_info.params.get("top_p", None) + + if model_info.params.get("system", None): + # Check if the payload already has a system message + # If not, add a system message to the payload + if payload.get("messages"): + for message in payload["messages"]: + if message.get("role") == "system": + message["content"] = ( + model_info.params.get("system", None) + message["content"] + ) + break + else: + payload["messages"].insert( + 0, + { + "role": "system", + "content": model_info.params.get("system", None), + }, + ) + if url_idx == None: - model = form_data.model + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if payload["model"] in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"]) else: raise HTTPException( status_code=400, @@ -865,16 +975,12 @@ async def generate_chat_completion( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + print(payload) + r = None - log.debug( - "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( - form_data.model_dump_json(exclude_none=True).encode() - ) - ) - def get_request(): - nonlocal form_data 
+ nonlocal payload nonlocal r request_id = str(uuid.uuid4()) @@ -883,7 +989,7 @@ async def generate_chat_completion( def stream_content(): try: - if form_data.stream: + if payload.get("stream", None): yield json.dumps({"id": request_id, "done": False}) + "\n" for chunk in r.iter_content(chunk_size=8192): @@ -901,7 +1007,7 @@ async def generate_chat_completion( r = requests.request( method="POST", url=f"{url}/api/chat", - data=form_data.model_dump_json(exclude_none=True).encode(), + data=json.dumps(payload), stream=True, ) @@ -957,14 +1063,62 @@ async def generate_openai_chat_completion( user=Depends(get_verified_user), ): + payload = { + **form_data.model_dump(exclude_none=True), + } + + model_id = form_data.model + model_info = Models.get_model_by_id(model_id) + + if model_info: + print(model_info) + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + model_info.params = model_info.params.model_dump() + + if model_info.params: + payload["temperature"] = model_info.params.get("temperature", None) + payload["top_p"] = model_info.params.get("top_p", None) + payload["max_tokens"] = model_info.params.get("max_tokens", None) + payload["frequency_penalty"] = model_info.params.get( + "frequency_penalty", None + ) + payload["seed"] = model_info.params.get("seed", None) + payload["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + else None + ) + + if model_info.params.get("system", None): + # Check if the payload already has a system message + # If not, add a system message to the payload + if payload.get("messages"): + for message in payload["messages"]: + if message.get("role") == "system": + message["content"] = ( + model_info.params.get("system", None) + message["content"] + ) + break + else: + payload["messages"].insert( + 0, + { + "role": "system", + "content": model_info.params.get("system", None), + }, + ) + if url_idx == None: - model = 
form_data.model + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if payload["model"] in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"]) else: raise HTTPException( status_code=400, @@ -977,7 +1131,7 @@ async def generate_openai_chat_completion( r = None def get_request(): - nonlocal form_data + nonlocal payload nonlocal r request_id = str(uuid.uuid4()) @@ -986,7 +1140,7 @@ async def generate_openai_chat_completion( def stream_content(): try: - if form_data.stream: + if payload.get("stream"): yield json.dumps( {"request_id": request_id, "done": False} ) + "\n" @@ -1006,7 +1160,7 @@ async def generate_openai_chat_completion( r = requests.request( method="POST", url=f"{url}/v1/chat/completions", - data=form_data.model_dump_json(exclude_none=True).encode(), + data=json.dumps(payload), stream=True, ) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 85ee531f1..e19e0bbcf 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -10,8 +10,8 @@ import logging from pydantic import BaseModel - -from apps.web.models.users import Users +from apps.webui.models.models import Models +from apps.webui.models.users import Users from constants import ERROR_MESSAGES from utils.utils import ( decode_token, @@ -53,7 +53,6 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS @@ -199,14 +198,20 @@ async def fetch_url(url, key): def merge_models_lists(model_lists): - log.info(f"merge_models_lists {model_lists}") + log.debug(f"merge_models_lists {model_lists}") 
merged_list = [] for idx, models in enumerate(model_lists): if models is not None and "error" not in models: merged_list.extend( [ - {**model, "urlIdx": idx} + { + **model, + "name": model.get("name", model["id"]), + "owned_by": "openai", + "openai": model, + "urlIdx": idx, + } for model in models if "api.openai.com" not in app.state.config.OPENAI_API_BASE_URLS[idx] @@ -232,7 +237,7 @@ async def get_all_models(): ] responses = await asyncio.gather(*tasks) - log.info(f"get_all_models:responses() {responses}") + log.debug(f"get_all_models:responses() {responses}") models = { "data": merge_models_lists( @@ -249,10 +254,10 @@ async def get_all_models(): ) } - log.info(f"models: {models}") + log.debug(f"models: {models}") app.state.MODELS = {model["id"]: model for model in models["data"]} - return models + return models @app.get("/models") @@ -310,31 +315,93 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): body = await request.body() # TODO: Remove below after gpt-4-vision fix from Open AI # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision) + + payload = None + try: - body = body.decode("utf-8") - body = json.loads(body) + if "chat/completions" in path: + body = body.decode("utf-8") + body = json.loads(body) - idx = app.state.MODELS[body.get("model")]["urlIdx"] + payload = {**body} - # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 - # This is a workaround until OpenAI fixes the issue with this model - if body.get("model") == "gpt-4-vision-preview": - if "max_tokens" not in body: - body["max_tokens"] = 4000 - log.debug("Modified body_dict:", body) + model_id = body.get("model") + model_info = Models.get_model_by_id(model_id) - # Fix for ChatGPT calls failing because the num_ctx key is in body - if "num_ctx" in body: - # If 'num_ctx' is in the dictionary, delete it - # Leaving it there generates an error with the - # OpenAI API (Feb 2024) - del 
body["num_ctx"] + if model_info: + print(model_info) + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + model_info.params = model_info.params.model_dump() + + if model_info.params: + payload["temperature"] = model_info.params.get("temperature", None) + payload["top_p"] = model_info.params.get("top_p", None) + payload["max_tokens"] = model_info.params.get("max_tokens", None) + payload["frequency_penalty"] = model_info.params.get( + "frequency_penalty", None + ) + payload["seed"] = model_info.params.get("seed", None) + payload["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + else None + ) + + if model_info.params.get("system", None): + # Check if the payload already has a system message + # If not, add a system message to the payload + if payload.get("messages"): + for message in payload["messages"]: + if message.get("role") == "system": + message["content"] = ( + model_info.params.get("system", None) + + message["content"] + ) + break + else: + payload["messages"].insert( + 0, + { + "role": "system", + "content": model_info.params.get("system", None), + }, + ) + else: + pass + + print(app.state.MODELS) + model = app.state.MODELS[payload.get("model")] + + idx = model["urlIdx"] + + if "pipeline" in model and model.get("pipeline"): + payload["user"] = {"name": user.name, "id": user.id} + payload["title"] = ( + True + if payload["stream"] == False and payload["max_tokens"] == 50 + else False + ) + + # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 + # This is a workaround until OpenAI fixes the issue with this model + if payload.get("model") == "gpt-4-vision-preview": + if "max_tokens" not in payload: + payload["max_tokens"] = 4000 + log.debug("Modified payload:", payload) + + # Convert the modified body back to JSON + payload = json.dumps(payload) - # Convert the modified body back to JSON - body = 
json.dumps(body) except json.JSONDecodeError as e: log.error("Error loading request body into a dictionary:", e) + print(payload) + url = app.state.config.OPENAI_API_BASE_URLS[idx] key = app.state.config.OPENAI_API_KEYS[idx] @@ -353,7 +420,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): r = requests.request( method=request.method, url=target_url, - data=body, + data=payload if payload else body, headers=headers, stream=True, ) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index f08d81a3b..d04c256d7 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -46,7 +46,7 @@ import json import sentence_transformers -from apps.web.models.documents import ( +from apps.webui.models.documents import ( Documents, DocumentForm, DocumentResponse, diff --git a/backend/apps/web/models/modelfiles.py b/backend/apps/web/models/modelfiles.py deleted file mode 100644 index 1d60d7c55..000000000 --- a/backend/apps/web/models/modelfiles.py +++ /dev/null @@ -1,136 +0,0 @@ -from pydantic import BaseModel -from peewee import * -from playhouse.shortcuts import model_to_dict -from typing import List, Union, Optional -import time - -from utils.utils import decode_token -from utils.misc import get_gravatar_url - -from apps.web.internal.db import DB - -import json - -#################### -# Modelfile DB Schema -#################### - - -class Modelfile(Model): - tag_name = CharField(unique=True) - user_id = CharField() - modelfile = TextField() - timestamp = BigIntegerField() - - class Meta: - database = DB - - -class ModelfileModel(BaseModel): - tag_name: str - user_id: str - modelfile: str - timestamp: int # timestamp in epoch - - -#################### -# Forms -#################### - - -class ModelfileForm(BaseModel): - modelfile: dict - - -class ModelfileTagNameForm(BaseModel): - tag_name: str - - -class ModelfileUpdateForm(ModelfileForm, ModelfileTagNameForm): - pass - - -class ModelfileResponse(BaseModel): - tag_name: 
str - user_id: str - modelfile: dict - timestamp: int # timestamp in epoch - - -class ModelfilesTable: - - def __init__(self, db): - self.db = db - self.db.create_tables([Modelfile]) - - def insert_new_modelfile( - self, user_id: str, form_data: ModelfileForm - ) -> Optional[ModelfileModel]: - if "tagName" in form_data.modelfile: - modelfile = ModelfileModel( - **{ - "user_id": user_id, - "tag_name": form_data.modelfile["tagName"], - "modelfile": json.dumps(form_data.modelfile), - "timestamp": int(time.time()), - } - ) - - try: - result = Modelfile.create(**modelfile.model_dump()) - if result: - return modelfile - else: - return None - except: - return None - - else: - return None - - def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]: - try: - modelfile = Modelfile.get(Modelfile.tag_name == tag_name) - return ModelfileModel(**model_to_dict(modelfile)) - except: - return None - - def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]: - return [ - ModelfileResponse( - **{ - **model_to_dict(modelfile), - "modelfile": json.loads(modelfile.modelfile), - } - ) - for modelfile in Modelfile.select() - # .limit(limit).offset(skip) - ] - - def update_modelfile_by_tag_name( - self, tag_name: str, modelfile: dict - ) -> Optional[ModelfileModel]: - try: - query = Modelfile.update( - modelfile=json.dumps(modelfile), - timestamp=int(time.time()), - ).where(Modelfile.tag_name == tag_name) - - query.execute() - - modelfile = Modelfile.get(Modelfile.tag_name == tag_name) - return ModelfileModel(**model_to_dict(modelfile)) - except: - return None - - def delete_modelfile_by_tag_name(self, tag_name: str) -> bool: - try: - query = Modelfile.delete().where((Modelfile.tag_name == tag_name)) - query.execute() # Remove the rows, return number of rows removed. 
- - return True - except: - return False - - -Modelfiles = ModelfilesTable(DB) diff --git a/backend/apps/web/routers/modelfiles.py b/backend/apps/web/routers/modelfiles.py deleted file mode 100644 index 3cdbf8a74..000000000 --- a/backend/apps/web/routers/modelfiles.py +++ /dev/null @@ -1,124 +0,0 @@ -from fastapi import Depends, FastAPI, HTTPException, status -from datetime import datetime, timedelta -from typing import List, Union, Optional - -from fastapi import APIRouter -from pydantic import BaseModel -import json -from apps.web.models.modelfiles import ( - Modelfiles, - ModelfileForm, - ModelfileTagNameForm, - ModelfileUpdateForm, - ModelfileResponse, -) - -from utils.utils import get_current_user, get_admin_user -from constants import ERROR_MESSAGES - -router = APIRouter() - -############################ -# GetModelfiles -############################ - - -@router.get("/", response_model=List[ModelfileResponse]) -async def get_modelfiles( - skip: int = 0, limit: int = 50, user=Depends(get_current_user) -): - return Modelfiles.get_modelfiles(skip, limit) - - -############################ -# CreateNewModelfile -############################ - - -@router.post("/create", response_model=Optional[ModelfileResponse]) -async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)): - modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) - - if modelfile: - return ModelfileResponse( - **{ - **modelfile.model_dump(), - "modelfile": json.loads(modelfile.modelfile), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), - ) - - -############################ -# GetModelfileByTagName -############################ - - -@router.post("/", response_model=Optional[ModelfileResponse]) -async def get_modelfile_by_tag_name( - form_data: ModelfileTagNameForm, user=Depends(get_current_user) -): - modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) - - if modelfile: - 
return ModelfileResponse( - **{ - **modelfile.model_dump(), - "modelfile": json.loads(modelfile.modelfile), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - - -############################ -# UpdateModelfileByTagName -############################ - - -@router.post("/update", response_model=Optional[ModelfileResponse]) -async def update_modelfile_by_tag_name( - form_data: ModelfileUpdateForm, user=Depends(get_admin_user) -): - modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) - if modelfile: - updated_modelfile = { - **json.loads(modelfile.modelfile), - **form_data.modelfile, - } - - modelfile = Modelfiles.update_modelfile_by_tag_name( - form_data.tag_name, updated_modelfile - ) - - return ModelfileResponse( - **{ - **modelfile.model_dump(), - "modelfile": json.loads(modelfile.modelfile), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - - -############################ -# DeleteModelfileByTagName -############################ - - -@router.delete("/delete", response_model=bool) -async def delete_modelfile_by_tag_name( - form_data: ModelfileTagNameForm, user=Depends(get_admin_user) -): - result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name) - return result diff --git a/backend/apps/web/internal/db.py b/backend/apps/webui/internal/db.py similarity index 58% rename from backend/apps/web/internal/db.py rename to backend/apps/webui/internal/db.py index 136e3fafc..0e7b1f95d 100644 --- a/backend/apps/web/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -1,13 +1,25 @@ +import json + from peewee import * from peewee_migrate import Router from playhouse.db_url import connect -from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL +from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR import os import logging log = logging.getLogger(__name__) 
log.setLevel(SRC_LOG_LEVELS["DB"]) + +class JSONField(TextField): + def db_value(self, value): + return json.dumps(value) + + def python_value(self, value): + if value is not None: + return json.loads(value) + + # Check if the file exists if os.path.exists(f"{DATA_DIR}/ollama.db"): # Rename the file @@ -18,6 +30,10 @@ else: DB = connect(DATABASE_URL) log.info(f"Connected to a {DB.__class__.__name__} database.") -router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log) +router = Router( + DB, + migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations", + logger=log, +) router.run() DB.connect(reuse_if_open=True) diff --git a/backend/apps/web/internal/migrations/001_initial_schema.py b/backend/apps/webui/internal/migrations/001_initial_schema.py similarity index 100% rename from backend/apps/web/internal/migrations/001_initial_schema.py rename to backend/apps/webui/internal/migrations/001_initial_schema.py diff --git a/backend/apps/web/internal/migrations/002_add_local_sharing.py b/backend/apps/webui/internal/migrations/002_add_local_sharing.py similarity index 100% rename from backend/apps/web/internal/migrations/002_add_local_sharing.py rename to backend/apps/webui/internal/migrations/002_add_local_sharing.py diff --git a/backend/apps/web/internal/migrations/003_add_auth_api_key.py b/backend/apps/webui/internal/migrations/003_add_auth_api_key.py similarity index 100% rename from backend/apps/web/internal/migrations/003_add_auth_api_key.py rename to backend/apps/webui/internal/migrations/003_add_auth_api_key.py diff --git a/backend/apps/web/internal/migrations/004_add_archived.py b/backend/apps/webui/internal/migrations/004_add_archived.py similarity index 100% rename from backend/apps/web/internal/migrations/004_add_archived.py rename to backend/apps/webui/internal/migrations/004_add_archived.py diff --git a/backend/apps/web/internal/migrations/005_add_updated_at.py b/backend/apps/webui/internal/migrations/005_add_updated_at.py 
similarity index 100% rename from backend/apps/web/internal/migrations/005_add_updated_at.py rename to backend/apps/webui/internal/migrations/005_add_updated_at.py diff --git a/backend/apps/web/internal/migrations/006_migrate_timestamps_and_charfields.py b/backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py similarity index 100% rename from backend/apps/web/internal/migrations/006_migrate_timestamps_and_charfields.py rename to backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py diff --git a/backend/apps/web/internal/migrations/007_add_user_last_active_at.py b/backend/apps/webui/internal/migrations/007_add_user_last_active_at.py similarity index 100% rename from backend/apps/web/internal/migrations/007_add_user_last_active_at.py rename to backend/apps/webui/internal/migrations/007_add_user_last_active_at.py diff --git a/backend/apps/web/internal/migrations/008_add_memory.py b/backend/apps/webui/internal/migrations/008_add_memory.py similarity index 100% rename from backend/apps/web/internal/migrations/008_add_memory.py rename to backend/apps/webui/internal/migrations/008_add_memory.py diff --git a/backend/apps/webui/internal/migrations/009_add_models.py b/backend/apps/webui/internal/migrations/009_add_models.py new file mode 100644 index 000000000..548ec7cdc --- /dev/null +++ b/backend/apps/webui/internal/migrations/009_add_models.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. 
+ +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Model(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + base_model_id = pw.TextField(null=True) + + name = pw.TextField() + + meta = pw.TextField() + params = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "model" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("model") diff --git 
a/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py b/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py new file mode 100644 index 000000000..2ef814c06 --- /dev/null +++ b/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py @@ -0,0 +1,130 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator +import json + +from utils.misc import parse_ollama_modelfile + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + # Fetch data from 'modelfile' table and insert into 'model' table + migrate_modelfile_to_model(migrator, database) + # Drop the 'modelfile' table + 
migrator.remove_model("modelfile") + + +def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database): + ModelFile = migrator.orm["modelfile"] + Model = migrator.orm["model"] + + modelfiles = ModelFile.select() + + for modelfile in modelfiles: + # Extract and transform data in Python + + modelfile.modelfile = json.loads(modelfile.modelfile) + meta = json.dumps( + { + "description": modelfile.modelfile.get("desc"), + "profile_image_url": modelfile.modelfile.get("imageUrl"), + "ollama": {"modelfile": modelfile.modelfile.get("content")}, + "suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"), + "categories": modelfile.modelfile.get("categories"), + "user": {**modelfile.modelfile.get("user", {}), "community": True}, + } + ) + + info = parse_ollama_modelfile(modelfile.modelfile.get("content")) + + # Insert the processed data into the 'model' table + Model.create( + id=f"ollama-{modelfile.tag_name}", + user_id=modelfile.user_id, + base_model_id=info.get("base_model_id"), + name=modelfile.modelfile.get("title"), + meta=meta, + params=json.dumps(info.get("params", {})), + created_at=modelfile.timestamp, + updated_at=modelfile.timestamp, + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + recreate_modelfile_table(migrator, database) + move_data_back_to_modelfile(migrator, database) + migrator.remove_model("model") + + +def recreate_modelfile_table(migrator: Migrator, database: pw.Database): + query = """ + CREATE TABLE IF NOT EXISTS modelfile ( + user_id TEXT, + tag_name TEXT, + modelfile JSON, + timestamp BIGINT + ) + """ + migrator.sql(query) + + +def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database): + Model = migrator.orm["model"] + Modelfile = migrator.orm["modelfile"] + + models = Model.select() + + for model in models: + # Extract and transform data in Python + meta = json.loads(model.meta) + + modelfile_data = { + "title": model.name, + 
"desc": meta.get("description"), + "imageUrl": meta.get("profile_image_url"), + "content": meta.get("ollama", {}).get("modelfile"), + "suggestionPrompts": meta.get("suggestion_prompts"), + "categories": meta.get("categories"), + "user": {k: v for k, v in meta.get("user", {}).items() if k != "community"}, + } + + # Insert the processed data back into the 'modelfile' table + Modelfile.create( + user_id=model.user_id, + tag_name=model.id, + modelfile=modelfile_data, + timestamp=model.created_at, + ) diff --git a/backend/apps/webui/internal/migrations/011_add_user_settings.py b/backend/apps/webui/internal/migrations/011_add_user_settings.py new file mode 100644 index 000000000..a1620dcad --- /dev/null +++ b/backend/apps/webui/internal/migrations/011_add_user_settings.py @@ -0,0 +1,48 @@ +"""Peewee migrations -- 002_add_local_sharing.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress 
+ +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + # Adding fields settings to the 'user' table + migrator.add_fields("user", settings=pw.TextField(null=True)) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + # Remove the settings field + migrator.remove_fields("user", "settings") diff --git a/backend/apps/web/internal/migrations/README.md b/backend/apps/webui/internal/migrations/README.md similarity index 84% rename from backend/apps/web/internal/migrations/README.md rename to backend/apps/webui/internal/migrations/README.md index 63d92e802..260214113 100644 --- a/backend/apps/web/internal/migrations/README.md +++ b/backend/apps/webui/internal/migrations/README.md @@ -14,7 +14,7 @@ You will need to create a migration file to ensure that existing databases are u 2. Make your changes to the models. 3. From the `backend` directory, run the following command: ```bash - pw_migrate create --auto --auto-source apps.web.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME} + pw_migrate create --auto --auto-source apps.webui.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME} ``` - `$SQLITE_DB` should be the path to the database file. - `$MIGRATION_NAME` should be a descriptive name for the migration. 
diff --git a/backend/apps/web/main.py b/backend/apps/webui/main.py similarity index 85% rename from backend/apps/web/main.py rename to backend/apps/webui/main.py index 2b6966381..b823859a6 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/webui/main.py @@ -1,19 +1,19 @@ from fastapi import FastAPI, Depends from fastapi.routing import APIRoute from fastapi.middleware.cors import CORSMiddleware -from apps.web.routers import ( +from apps.webui.routers import ( auths, users, chats, documents, - modelfiles, + models, prompts, configs, memories, utils, ) from config import ( - WEBUI_VERSION, + WEBUI_BUILD_HASH, WEBUI_AUTH, DEFAULT_MODELS, DEFAULT_PROMPT_SUGGESTIONS, @@ -23,7 +23,9 @@ from config import ( WEBHOOK_URL, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, JWT_EXPIRES_IN, + WEBUI_BANNERS, AppConfig, + ENABLE_COMMUNITY_SHARING, ) app = FastAPI() @@ -40,6 +42,11 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.WEBHOOK_URL = WEBHOOK_URL +app.state.config.BANNERS = WEBUI_BANNERS + +app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING + +app.state.MODELS = {} app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER @@ -56,11 +63,10 @@ app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(documents.router, prefix="/documents", tags=["documents"]) -app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) +app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(memories.router, prefix="/memories", tags=["memories"]) - app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(utils.router, prefix="/utils", tags=["utils"]) diff --git 
a/backend/apps/web/models/auths.py b/backend/apps/webui/models/auths.py similarity index 98% rename from backend/apps/web/models/auths.py rename to backend/apps/webui/models/auths.py index dfa0c4395..e3b659e43 100644 --- a/backend/apps/web/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -5,10 +5,10 @@ import uuid import logging from peewee import * -from apps.web.models.users import UserModel, Users +from apps.webui.models.users import UserModel, Users from utils.utils import verify_password -from apps.web.internal.db import DB +from apps.webui.internal.db import DB from config import SRC_LOG_LEVELS diff --git a/backend/apps/web/models/chats.py b/backend/apps/webui/models/chats.py similarity index 88% rename from backend/apps/web/models/chats.py rename to backend/apps/webui/models/chats.py index 891151b94..d4597f16d 100644 --- a/backend/apps/web/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -7,7 +7,7 @@ import json import uuid import time -from apps.web.internal.db import DB +from apps.webui.internal.db import DB #################### # Chat DB Schema @@ -191,6 +191,20 @@ class ChatTable: except: return None + def archive_all_chats_by_user_id(self, user_id: str) -> bool: + try: + chats = self.get_chats_by_user_id(user_id) + for chat in chats: + query = Chat.update( + archived=True, + ).where(Chat.id == chat.id) + + query.execute() + + return True + except: + return False + def get_archived_chat_list_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: @@ -205,17 +219,31 @@ class ChatTable: ] def get_chat_list_by_user_id( - self, user_id: str, skip: int = 0, limit: int = 50 + self, + user_id: str, + include_archived: bool = False, + skip: int = 0, + limit: int = 50, ) -> List[ChatModel]: - return [ - ChatModel(**model_to_dict(chat)) - for chat in Chat.select() - .where(Chat.archived == False) - .where(Chat.user_id == user_id) - .order_by(Chat.updated_at.desc()) - # .limit(limit) - # .offset(skip) - ] + if 
include_archived: + return [ + ChatModel(**model_to_dict(chat)) + for chat in Chat.select() + .where(Chat.user_id == user_id) + .order_by(Chat.updated_at.desc()) + # .limit(limit) + # .offset(skip) + ] + else: + return [ + ChatModel(**model_to_dict(chat)) + for chat in Chat.select() + .where(Chat.archived == False) + .where(Chat.user_id == user_id) + .order_by(Chat.updated_at.desc()) + # .limit(limit) + # .offset(skip) + ] def get_chat_list_by_chat_ids( self, chat_ids: List[str], skip: int = 0, limit: int = 50 diff --git a/backend/apps/web/models/documents.py b/backend/apps/webui/models/documents.py similarity index 99% rename from backend/apps/web/models/documents.py rename to backend/apps/webui/models/documents.py index 42b99596c..3b730535f 100644 --- a/backend/apps/web/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -8,7 +8,7 @@ import logging from utils.utils import decode_token from utils.misc import get_gravatar_url -from apps.web.internal.db import DB +from apps.webui.internal.db import DB import json diff --git a/backend/apps/web/models/memories.py b/backend/apps/webui/models/memories.py similarity index 97% rename from backend/apps/web/models/memories.py rename to backend/apps/webui/models/memories.py index 8382b3e52..70e5577e9 100644 --- a/backend/apps/web/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -3,8 +3,8 @@ from peewee import * from playhouse.shortcuts import model_to_dict from typing import List, Union, Optional -from apps.web.internal.db import DB -from apps.web.models.chats import Chats +from apps.webui.internal.db import DB +from apps.webui.models.chats import Chats import time import uuid diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py new file mode 100644 index 000000000..851352398 --- /dev/null +++ b/backend/apps/webui/models/models.py @@ -0,0 +1,179 @@ +import json +import logging +from typing import Optional + +import peewee as pw +from peewee import * + +from 
playhouse.shortcuts import model_to_dict +from pydantic import BaseModel, ConfigDict + +from apps.webui.internal.db import DB, JSONField + +from typing import List, Union, Optional +from config import SRC_LOG_LEVELS + +import time + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + + +#################### +# Models DB Schema +#################### + + +# ModelParams is a model for the data stored in the params field of the Model table +class ModelParams(BaseModel): + model_config = ConfigDict(extra="allow") + pass + + +# ModelMeta is a model for the data stored in the meta field of the Model table +class ModelMeta(BaseModel): + profile_image_url: Optional[str] = "/favicon.png" + + description: Optional[str] = None + """ + User-facing description of the model. + """ + + capabilities: Optional[dict] = None + + model_config = ConfigDict(extra="allow") + + pass + + +class Model(pw.Model): + id = pw.TextField(unique=True) + """ + The model's id as used in the API. If set to an existing model, it will override the model. + """ + user_id = pw.TextField() + + base_model_id = pw.TextField(null=True) + """ + An optional pointer to the actual model that should be used when proxying requests. + """ + + name = pw.TextField() + """ + The human-readable display name of the model. + """ + + params = JSONField() + """ + Holds a JSON encoded blob of parameters, see `ModelParams`. + """ + + meta = JSONField() + """ + Holds a JSON encoded blob of metadata, see `ModelMeta`. 
+ """ + + updated_at = BigIntegerField() + created_at = BigIntegerField() + + class Meta: + database = DB + + +class ModelModel(BaseModel): + id: str + user_id: str + base_model_id: Optional[str] = None + + name: str + params: ModelParams + meta: ModelMeta + + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class ModelResponse(BaseModel): + id: str + name: str + meta: ModelMeta + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +class ModelForm(BaseModel): + id: str + base_model_id: Optional[str] = None + name: str + meta: ModelMeta + params: ModelParams + + +class ModelsTable: + def __init__( + self, + db: pw.SqliteDatabase | pw.PostgresqlDatabase, + ): + self.db = db + self.db.create_tables([Model]) + + def insert_new_model( + self, form_data: ModelForm, user_id: str + ) -> Optional[ModelModel]: + model = ModelModel( + **{ + **form_data.model_dump(), + "user_id": user_id, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + try: + result = Model.create(**model.model_dump()) + + if result: + return model + else: + return None + except Exception as e: + print(e) + return None + + def get_all_models(self) -> List[ModelModel]: + return [ModelModel(**model_to_dict(model)) for model in Model.select()] + + def get_model_by_id(self, id: str) -> Optional[ModelModel]: + try: + model = Model.get(Model.id == id) + return ModelModel(**model_to_dict(model)) + except: + return None + + def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: + try: + # update only the fields that are present in the model + query = Model.update(**model.model_dump()).where(Model.id == id) + query.execute() + + model = Model.get(Model.id == id) + return ModelModel(**model_to_dict(model)) + except Exception as e: + print(e) + + return None + + def delete_model_by_id(self, id: str) -> bool: + try: + query = 
Model.delete().where(Model.id == id) + query.execute() + return True + except: + return False + + +Models = ModelsTable(DB) diff --git a/backend/apps/web/models/prompts.py b/backend/apps/webui/models/prompts.py similarity index 98% rename from backend/apps/web/models/prompts.py rename to backend/apps/webui/models/prompts.py index bc4e3e58b..c4ac6be14 100644 --- a/backend/apps/web/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -7,7 +7,7 @@ import time from utils.utils import decode_token from utils.misc import get_gravatar_url -from apps.web.internal.db import DB +from apps.webui.internal.db import DB import json diff --git a/backend/apps/web/models/tags.py b/backend/apps/webui/models/tags.py similarity index 99% rename from backend/apps/web/models/tags.py rename to backend/apps/webui/models/tags.py index d9a967ff7..4c4fa82e6 100644 --- a/backend/apps/web/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -8,7 +8,7 @@ import uuid import time import logging -from apps.web.internal.db import DB +from apps.webui.internal.db import DB from config import SRC_LOG_LEVELS diff --git a/backend/apps/web/models/users.py b/backend/apps/webui/models/users.py similarity index 94% rename from backend/apps/web/models/users.py rename to backend/apps/webui/models/users.py index 450dd9187..48811e8af 100644 --- a/backend/apps/web/models/users.py +++ b/backend/apps/webui/models/users.py @@ -1,12 +1,12 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from peewee import * from playhouse.shortcuts import model_to_dict from typing import List, Union, Optional import time from utils.misc import get_gravatar_url -from apps.web.internal.db import DB -from apps.web.models.chats import Chats +from apps.webui.internal.db import DB, JSONField +from apps.webui.models.chats import Chats #################### # User DB Schema @@ -25,11 +25,18 @@ class User(Model): created_at = BigIntegerField() api_key = CharField(null=True, unique=True) + settings 
= JSONField(null=True) class Meta: database = DB +class UserSettings(BaseModel): + ui: Optional[dict] = {} + model_config = ConfigDict(extra="allow") + pass + + class UserModel(BaseModel): id: str name: str @@ -42,6 +49,7 @@ class UserModel(BaseModel): created_at: int # timestamp in epoch api_key: Optional[str] = None + settings: Optional[UserSettings] = None #################### diff --git a/backend/apps/web/routers/auths.py b/backend/apps/webui/routers/auths.py similarity index 99% rename from backend/apps/web/routers/auths.py rename to backend/apps/webui/routers/auths.py index 998e74659..ce9b92061 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/webui/routers/auths.py @@ -10,7 +10,7 @@ import uuid import csv -from apps.web.models.auths import ( +from apps.webui.models.auths import ( SigninForm, SignupForm, AddUserForm, @@ -21,7 +21,7 @@ from apps.web.models.auths import ( Auths, ApiKey, ) -from apps.web.models.users import Users +from apps.webui.models.users import Users from utils.utils import ( get_password_hash, diff --git a/backend/apps/web/routers/chats.py b/backend/apps/webui/routers/chats.py similarity index 95% rename from backend/apps/web/routers/chats.py rename to backend/apps/webui/routers/chats.py index aaf173521..5d52f40c9 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/webui/routers/chats.py @@ -7,8 +7,8 @@ from pydantic import BaseModel import json import logging -from apps.web.models.users import Users -from apps.web.models.chats import ( +from apps.webui.models.users import Users +from apps.webui.models.chats import ( ChatModel, ChatResponse, ChatTitleForm, @@ -18,7 +18,7 @@ from apps.web.models.chats import ( ) -from apps.web.models.tags import ( +from apps.webui.models.tags import ( TagModel, ChatIdTagModel, ChatIdTagForm, @@ -78,43 +78,25 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user) async def get_user_chat_list_by_user_id( user_id: str, user=Depends(get_admin_user), 
skip: int = 0, limit: int = 50 ): - return Chats.get_chat_list_by_user_id(user_id, skip, limit) + return Chats.get_chat_list_by_user_id( + user_id, include_archived=True, skip=skip, limit=limit + ) ############################ -# GetArchivedChats +# CreateNewChat ############################ -@router.get("/archived", response_model=List[ChatTitleIdResponse]) -async def get_archived_session_user_chat_list( - user=Depends(get_current_user), skip: int = 0, limit: int = 50 -): - return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) - - -############################ -# GetSharedChatById -############################ - - -@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) -async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): - if user.role == "pending": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND - ) - - if user.role == "user": - chat = Chats.get_chat_by_share_id(share_id) - elif user.role == "admin": - chat = Chats.get_chat_by_id(share_id) - - if chat: +@router.post("/new", response_model=Optional[ChatResponse]) +async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): + try: + chat = Chats.insert_new_chat(user.id, form_data) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - else: + except Exception as e: + log.exception(e) raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() ) @@ -150,19 +132,49 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): ############################ -# CreateNewChat +# GetArchivedChats ############################ -@router.post("/new", response_model=Optional[ChatResponse]) -async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): - try: - chat = Chats.insert_new_chat(user.id, form_data) - return 
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - except Exception as e: - log.exception(e) +@router.get("/archived", response_model=List[ChatTitleIdResponse]) +async def get_archived_session_user_chat_list( + user=Depends(get_current_user), skip: int = 0, limit: int = 50 +): + return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) + + +############################ +# ArchiveAllChats +############################ + + +@router.post("/archive/all", response_model=List[ChatTitleIdResponse]) +async def archive_all_chats(user=Depends(get_current_user)): + return Chats.archive_all_chats_by_user_id(user.id) + + +############################ +# GetSharedChatById +############################ + + +@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) +async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): + if user.role == "pending": raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role == "user": + chat = Chats.get_chat_by_share_id(share_id) + elif user.role == "admin": + chat = Chats.get_chat_by_id(share_id) + + if chat: + return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND ) diff --git a/backend/apps/web/routers/configs.py b/backend/apps/webui/routers/configs.py similarity index 67% rename from backend/apps/web/routers/configs.py rename to backend/apps/webui/routers/configs.py index 143ed5e0a..c127e721b 100644 --- a/backend/apps/web/routers/configs.py +++ b/backend/apps/webui/routers/configs.py @@ -8,7 +8,9 @@ from pydantic import BaseModel import time import uuid -from apps.web.models.users import Users +from config import BannerModel + +from apps.webui.models.users import Users from utils.utils import ( get_password_hash, @@ 
-57,3 +59,31 @@ async def set_global_default_suggestions( data = form_data.model_dump() request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS + + +############################ +# SetBanners +############################ + + +class SetBannersForm(BaseModel): + banners: List[BannerModel] + + +@router.post("/banners", response_model=List[BannerModel]) +async def set_banners( + request: Request, + form_data: SetBannersForm, + user=Depends(get_admin_user), +): + data = form_data.model_dump() + request.app.state.config.BANNERS = data["banners"] + return request.app.state.config.BANNERS + + +@router.get("/banners", response_model=List[BannerModel]) +async def get_banners( + request: Request, + user=Depends(get_current_user), +): + return request.app.state.config.BANNERS diff --git a/backend/apps/web/routers/documents.py b/backend/apps/webui/routers/documents.py similarity index 98% rename from backend/apps/web/routers/documents.py rename to backend/apps/webui/routers/documents.py index 7c69514fe..c5447a3fe 100644 --- a/backend/apps/web/routers/documents.py +++ b/backend/apps/webui/routers/documents.py @@ -6,7 +6,7 @@ from fastapi import APIRouter from pydantic import BaseModel import json -from apps.web.models.documents import ( +from apps.webui.models.documents import ( Documents, DocumentForm, DocumentUpdateForm, diff --git a/backend/apps/web/routers/memories.py b/backend/apps/webui/routers/memories.py similarity index 98% rename from backend/apps/web/routers/memories.py rename to backend/apps/webui/routers/memories.py index f20e02601..6448ebe1e 100644 --- a/backend/apps/web/routers/memories.py +++ b/backend/apps/webui/routers/memories.py @@ -7,7 +7,7 @@ from fastapi import APIRouter from pydantic import BaseModel import logging -from apps.web.models.memories import Memories, MemoryModel +from apps.webui.models.memories import Memories, MemoryModel from utils.utils import get_verified_user 
from constants import ERROR_MESSAGES diff --git a/backend/apps/webui/routers/models.py b/backend/apps/webui/routers/models.py new file mode 100644 index 000000000..363737e25 --- /dev/null +++ b/backend/apps/webui/routers/models.py @@ -0,0 +1,108 @@ +from fastapi import Depends, FastAPI, HTTPException, status, Request +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json +from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse + +from utils.utils import get_verified_user, get_admin_user +from constants import ERROR_MESSAGES + +router = APIRouter() + +########################### +# getModels +########################### + + +@router.get("/", response_model=List[ModelResponse]) +async def get_models(user=Depends(get_verified_user)): + return Models.get_all_models() + + +############################ +# AddNewModel +############################ + + +@router.post("/add", response_model=Optional[ModelModel]) +async def add_new_model( + request: Request, form_data: ModelForm, user=Depends(get_admin_user) +): + if form_data.id in request.app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.MODEL_ID_TAKEN, + ) + else: + model = Models.insert_new_model(form_data, user.id) + + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + +############################ +# GetModelById +############################ + + +@router.get("/", response_model=Optional[ModelModel]) +async def get_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateModelById +############################ + + 
+@router.post("/update", response_model=Optional[ModelModel]) +async def update_model_by_id( + request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user) +): + model = Models.get_model_by_id(id) + if model: + model = Models.update_model_by_id(id, form_data) + return model + else: + if form_data.id in request.app.state.MODELS: + model = Models.insert_new_model(form_data, user.id) + print(model) + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + +############################ +# DeleteModelById +############################ + + +@router.delete("/delete", response_model=bool) +async def delete_model_by_id(id: str, user=Depends(get_admin_user)): + result = Models.delete_model_by_id(id) + return result diff --git a/backend/apps/web/routers/prompts.py b/backend/apps/webui/routers/prompts.py similarity index 97% rename from backend/apps/web/routers/prompts.py rename to backend/apps/webui/routers/prompts.py index db7619676..47d8c7012 100644 --- a/backend/apps/web/routers/prompts.py +++ b/backend/apps/webui/routers/prompts.py @@ -6,7 +6,7 @@ from fastapi import APIRouter from pydantic import BaseModel import json -from apps.web.models.prompts import Prompts, PromptForm, PromptModel +from apps.webui.models.prompts import Prompts, PromptForm, PromptModel from utils.utils import get_current_user, get_admin_user from constants import ERROR_MESSAGES diff --git a/backend/apps/web/routers/users.py b/backend/apps/webui/routers/users.py similarity index 78% rename from backend/apps/web/routers/users.py rename to backend/apps/webui/routers/users.py index d77475d8d..cd17e3a7c 100644 --- a/backend/apps/web/routers/users.py +++ b/backend/apps/webui/routers/users.py @@ -9,9 +9,15 @@ import time import uuid import logging -from apps.web.models.users import UserModel, 
UserUpdateForm, UserRoleUpdateForm, Users -from apps.web.models.auths import Auths -from apps.web.models.chats import Chats +from apps.webui.models.users import ( + UserModel, + UserUpdateForm, + UserRoleUpdateForm, + UserSettings, + Users, +) +from apps.webui.models.auths import Auths +from apps.webui.models.chats import Chats from utils.utils import get_verified_user, get_password_hash, get_admin_user from constants import ERROR_MESSAGES @@ -68,6 +74,42 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin ) +############################ +# GetUserSettingsBySessionUser +############################ + + +@router.get("/user/settings", response_model=Optional[UserSettings]) +async def get_user_settings_by_session_user(user=Depends(get_verified_user)): + user = Users.get_user_by_id(user.id) + if user: + return user.settings + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + + +############################ +# UpdateUserSettingsBySessionUser +############################ + + +@router.post("/user/settings/update", response_model=UserSettings) +async def update_user_settings_by_session_user( + form_data: UserSettings, user=Depends(get_verified_user) +): + user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()}) + if user: + return user.settings + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + + ############################ # GetUserById ############################ @@ -81,6 +123,8 @@ class UserResponse(BaseModel): @router.get("/{user_id}", response_model=UserResponse) async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): + # Check if user_id is a shared chat + # If it is, get the user_id from the chat if user_id.startswith("shared-"): chat_id = user_id.replace("shared-", "") chat = Chats.get_chat_by_id(chat_id) diff --git a/backend/apps/web/routers/utils.py 
b/backend/apps/webui/routers/utils.py similarity index 98% rename from backend/apps/web/routers/utils.py rename to backend/apps/webui/routers/utils.py index 12805873d..b95fe8834 100644 --- a/backend/apps/web/routers/utils.py +++ b/backend/apps/webui/routers/utils.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from fpdf import FPDF import markdown -from apps.web.internal.db import DB +from apps.webui.internal.db import DB from utils.utils import get_admin_user from utils.misc import calculate_sha256, get_gravatar_url diff --git a/backend/config.py b/backend/config.py index 1a62e98bf..28ace5d5d 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,11 +1,15 @@ import os import sys import logging +import importlib.metadata +import pkgutil import chromadb from chromadb import Settings from base64 import b64encode from bs4 import BeautifulSoup from typing import TypeVar, Generic, Union +from pydantic import BaseModel +from typing import Optional from pathlib import Path import json @@ -22,10 +26,15 @@ from constants import ERROR_MESSAGES # Load .env file #################################### +BACKEND_DIR = Path(__file__).parent # the path containing this file +BASE_DIR = BACKEND_DIR.parent # the path containing the backend/ + +print(BASE_DIR) + try: from dotenv import load_dotenv, find_dotenv - load_dotenv(find_dotenv("../.env")) + load_dotenv(find_dotenv(str(BASE_DIR / ".env"))) except ImportError: print("dotenv not installed, skipping...") @@ -51,7 +60,6 @@ log_sources = [ "CONFIG", "DB", "IMAGES", - "LITELLM", "MAIN", "MODELS", "OLLAMA", @@ -87,10 +95,12 @@ WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" ENV = os.environ.get("ENV", "dev") try: - with open(f"../package.json", "r") as f: - PACKAGE_DATA = json.load(f) + PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text()) except: - PACKAGE_DATA = {"version": "0.0.0"} + try: + PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")} + except 
importlib.metadata.PackageNotFoundError: + PACKAGE_DATA = {"version": "0.0.0"} VERSION = PACKAGE_DATA["version"] @@ -115,10 +125,13 @@ def parse_section(section): try: - with open("../CHANGELOG.md", "r") as file: + changelog_path = BASE_DIR / "CHANGELOG.md" + with open(str(changelog_path.absolute()), "r", encoding="utf8") as file: changelog_content = file.read() + except: - changelog_content = "" + changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode() + # Convert markdown content to HTML html_content = markdown.markdown(changelog_content) @@ -155,21 +168,20 @@ CHANGELOG = changelog_json #################################### -# WEBUI_VERSION +# WEBUI_BUILD_HASH #################################### -WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100") +WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build") #################################### # DATA/FRONTEND BUILD DIR #################################### -DATA_DIR = str(Path(os.getenv("DATA_DIR", "./data")).resolve()) -FRONTEND_BUILD_DIR = str(Path(os.getenv("FRONTEND_BUILD_DIR", "../build"))) +DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve() +FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve() try: - with open(f"{DATA_DIR}/config.json", "r") as f: - CONFIG_DATA = json.load(f) + CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text()) except: CONFIG_DATA = {} @@ -279,11 +291,11 @@ JWT_EXPIRES_IN = PersistentConfig( # Static DIR #################################### -STATIC_DIR = str(Path(os.getenv("STATIC_DIR", "./static")).resolve()) +STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve() -frontend_favicon = f"{FRONTEND_BUILD_DIR}/favicon.png" -if os.path.exists(frontend_favicon): - shutil.copyfile(frontend_favicon, f"{STATIC_DIR}/favicon.png") +frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png" +if frontend_favicon.exists(): + shutil.copyfile(frontend_favicon, 
STATIC_DIR / "favicon.png") else: logging.warning(f"Frontend favicon not found at {frontend_favicon}") @@ -368,16 +380,23 @@ def create_config_file(file_path): LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml" -if not os.path.exists(LITELLM_CONFIG_PATH): - log.info("Config file doesn't exist. Creating...") - create_config_file(LITELLM_CONFIG_PATH) - log.info("Config file created successfully.") +# if not os.path.exists(LITELLM_CONFIG_PATH): +# log.info("Config file doesn't exist. Creating...") +# create_config_file(LITELLM_CONFIG_PATH) +# log.info("Config file created successfully.") #################################### # OLLAMA_BASE_URL #################################### + +ENABLE_OLLAMA_API = PersistentConfig( + "ENABLE_OLLAMA_API", + "ollama.enable", + os.environ.get("ENABLE_OLLAMA_API", "True").lower() == "true", +) + OLLAMA_API_BASE_URL = os.environ.get( "OLLAMA_API_BASE_URL", "http://localhost:11434/api" ) @@ -549,6 +568,27 @@ WEBHOOK_URL = PersistentConfig( ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true" +ENABLE_COMMUNITY_SHARING = PersistentConfig( + "ENABLE_COMMUNITY_SHARING", + "ui.enable_community_sharing", + os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true", +) + +class BannerModel(BaseModel): + id: str + type: str + title: Optional[str] = None + content: str + dismissible: bool + timestamp: int + + +WEBUI_BANNERS = PersistentConfig( + "WEBUI_BANNERS", + "ui.banners", + [BannerModel(**banner) for banner in json.loads("[]")], +) + #################################### # WEBUI_SECRET_KEY #################################### @@ -813,18 +853,6 @@ AUDIO_OPENAI_API_VOICE = PersistentConfig( os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"), ) -#################################### -# LiteLLM -#################################### - - -ENABLE_LITELLM = os.environ.get("ENABLE_LITELLM", "True").lower() == "true" - -LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365")) -if LITELLM_PROXY_PORT < 
0 or LITELLM_PROXY_PORT > 65535: - raise ValueError("Invalid port number for LITELLM_PROXY_PORT") -LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1") - #################################### # Database diff --git a/backend/constants.py b/backend/constants.py index be4d135b2..86875d2df 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -32,6 +32,8 @@ class ERROR_MESSAGES(str, Enum): COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file." + MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string." + NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string." INVALID_TOKEN = ( "Your session has expired or the token is invalid. Please sign in again." diff --git a/backend/main.py b/backend/main.py index 4cf3243f7..d1d267ce0 100644 --- a/backend/main.py +++ b/backend/main.py @@ -8,6 +8,7 @@ import sys import logging import aiohttp import requests +import mimetypes from fastapi import FastAPI, Request, Depends, status from fastapi.staticfiles import StaticFiles @@ -18,27 +19,20 @@ from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import StreamingResponse, Response -from apps.ollama.main import app as ollama_app -from apps.openai.main import app as openai_app - -from apps.litellm.main import ( - app as litellm_app, - start_litellm_background, - shutdown_litellm_background, -) - +from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models +from apps.openai.main import app as openai_app, get_all_models as get_openai_models from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app -from apps.web.main import app as webui_app +from 
apps.webui.main import app as webui_app import asyncio from pydantic import BaseModel -from typing import List +from typing import List, Optional - -from utils.utils import get_admin_user +from apps.webui.models.models import Models, ModelModel +from utils.utils import get_admin_user, get_verified_user from apps.rag.utils import rag_messages from config import ( @@ -52,7 +46,8 @@ from config import ( FRONTEND_BUILD_DIR, CACHE_DIR, STATIC_DIR, - ENABLE_LITELLM, + ENABLE_OPENAI_API, + ENABLE_OLLAMA_API, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, GLOBAL_LOG_LEVEL, @@ -60,6 +55,7 @@ from config import ( WEBHOOK_URL, ENABLE_ADMIN_EXPORT, AppConfig, + WEBUI_BUILD_HASH, ) from constants import ERROR_MESSAGES @@ -89,7 +85,8 @@ print( |_| -v{VERSION} - building the best open-source AI user interface. +v{VERSION} - building the best open-source AI user interface. +{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""} https://github.com/open-webui/open-webui """ ) @@ -97,11 +94,7 @@ https://github.com/open-webui/open-webui @asynccontextmanager async def lifespan(app: FastAPI): - if ENABLE_LITELLM: - asyncio.create_task(start_litellm_background()) yield - if ENABLE_LITELLM: - await shutdown_litellm_background() app = FastAPI( @@ -109,11 +102,19 @@ app = FastAPI( ) app.state.config = AppConfig() + +app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API +app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API + app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST + app.state.config.WEBHOOK_URL = WEBHOOK_URL + +app.state.MODELS = {} + origins = ["*"] @@ -230,6 +231,11 @@ app.add_middleware( @app.middleware("http") async def check_url(request: Request, call_next): + if len(app.state.MODELS) == 0: + await get_all_models() + else: + pass + start_time = int(time.time()) response = await call_next(request) process_time = int(time.time()) - start_time @@ -246,9 +252,8 @@ async def 
update_embedding_function(request: Request, call_next): return response -app.mount("/litellm/api", litellm_app) app.mount("/ollama", ollama_app) -app.mount("/openai/api", openai_app) +app.mount("/openai", openai_app) app.mount("/images/api/v1", images_app) app.mount("/audio/api/v1", audio_app) @@ -259,6 +264,87 @@ app.mount("/api/v1", webui_app) webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION +async def get_all_models(): + openai_models = [] + ollama_models = [] + + if app.state.config.ENABLE_OPENAI_API: + openai_models = await get_openai_models() + + openai_models = openai_models["data"] + + if app.state.config.ENABLE_OLLAMA_API: + ollama_models = await get_ollama_models() + + ollama_models = [ + { + "id": model["model"], + "name": model["name"], + "object": "model", + "created": int(time.time()), + "owned_by": "ollama", + "ollama": model, + } + for model in ollama_models["models"] + ] + + models = openai_models + ollama_models + custom_models = Models.get_all_models() + + for custom_model in custom_models: + if custom_model.base_model_id == None: + for model in models: + if ( + custom_model.id == model["id"] + or custom_model.id == model["id"].split(":")[0] + ): + model["name"] = custom_model.name + model["info"] = custom_model.model_dump() + else: + owned_by = "openai" + for model in models: + if ( + custom_model.base_model_id == model["id"] + or custom_model.base_model_id == model["id"].split(":")[0] + ): + owned_by = model["owned_by"] + break + + models.append( + { + "id": custom_model.id, + "name": custom_model.name, + "object": "model", + "created": custom_model.created_at, + "owned_by": owned_by, + "info": custom_model.model_dump(), + "preset": True, + } + ) + + app.state.MODELS = {model["id"]: model for model in models} + + webui_app.state.MODELS = app.state.MODELS + + return models + + +@app.get("/api/models") +async def get_models(user=Depends(get_verified_user)): + models = await get_all_models() + if 
app.state.config.ENABLE_MODEL_FILTER: + if user.role == "user": + models = list( + filter( + lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, + models, + ) + ) + return {"data": models} + + return {"data": models} + + @app.get("/api/config") async def get_app_config(): # Checking and Handling the Absence of 'ui' in CONFIG_DATA @@ -272,13 +358,17 @@ async def get_app_config(): "status": True, "name": WEBUI_NAME, "version": VERSION, - "auth": WEBUI_AUTH, "default_locale": default_locale, - "images": images_app.state.config.ENABLED, "default_models": webui_app.state.config.DEFAULT_MODELS, "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, - "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), - "admin_export_enabled": ENABLE_ADMIN_EXPORT, + "features": { + "auth": WEBUI_AUTH, + "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), + "enable_signup": webui_app.state.config.ENABLE_SIGNUP, + "enable_image_generation": images_app.state.config.ENABLED, + "enable_admin_export": ENABLE_ADMIN_EXPORT, + "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, + }, } @@ -302,15 +392,6 @@ async def update_model_filter_config( app.state.config.ENABLE_MODEL_FILTER = form_data.enabled app.state.config.MODEL_FILTER_LIST = form_data.models - ollama_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER - ollama_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - - openai_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER - openai_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - - litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER - litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - return { "enabled": app.state.config.ENABLE_MODEL_FILTER, "models": app.state.config.MODEL_FILTER_LIST, @@ -331,7 +412,6 @@ class UrlForm(BaseModel): 
@app.post("/api/webhook") async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): app.state.config.WEBHOOK_URL = form_data.url - webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL return { @@ -339,6 +419,19 @@ async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): } +@app.get("/api/community_sharing", response_model=bool) +async def get_community_sharing_status(request: Request, user=Depends(get_admin_user)): + return webui_app.state.config.ENABLE_COMMUNITY_SHARING + + +@app.get("/api/community_sharing/toggle", response_model=bool) +async def toggle_community_sharing(request: Request, user=Depends(get_admin_user)): + webui_app.state.config.ENABLE_COMMUNITY_SHARING = ( + not webui_app.state.config.ENABLE_COMMUNITY_SHARING + ) + return webui_app.state.config.ENABLE_COMMUNITY_SHARING + + @app.get("/api/version") async def get_app_config(): return { @@ -408,6 +501,7 @@ app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") if os.path.exists(FRONTEND_BUILD_DIR): + mimetypes.add_type("text/javascript", ".js") app.mount( "/", SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True), diff --git a/backend/open_webui/__init__.py b/backend/open_webui/__init__.py new file mode 100644 index 000000000..1defac824 --- /dev/null +++ b/backend/open_webui/__init__.py @@ -0,0 +1,60 @@ +import base64 +import os +import random +from pathlib import Path + +import typer +import uvicorn + +app = typer.Typer() + +KEY_FILE = Path.cwd() / ".webui_secret_key" +if (frontend_build_dir := Path(__file__).parent / "frontend").exists(): + os.environ["FRONTEND_BUILD_DIR"] = str(frontend_build_dir) + + +@app.command() +def serve( + host: str = "0.0.0.0", + port: int = 8080, +): + if os.getenv("WEBUI_SECRET_KEY") is None: + typer.echo( + "Loading WEBUI_SECRET_KEY from file, not provided as an environment variable." 
+ ) + if not KEY_FILE.exists(): + typer.echo(f"Generating a new secret key and saving it to {KEY_FILE}") + KEY_FILE.write_bytes(base64.b64encode(random.randbytes(12))) + typer.echo(f"Loading WEBUI_SECRET_KEY from {KEY_FILE}") + os.environ["WEBUI_SECRET_KEY"] = KEY_FILE.read_text() + + if os.getenv("USE_CUDA_DOCKER", "false") == "true": + typer.echo( + "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries." + ) + LD_LIBRARY_PATH = os.getenv("LD_LIBRARY_PATH", "").split(":") + os.environ["LD_LIBRARY_PATH"] = ":".join( + LD_LIBRARY_PATH + + [ + "/usr/local/lib/python3.11/site-packages/torch/lib", + "/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib", + ] + ) + import main # we need set environment variables before importing main + + uvicorn.run(main.app, host=host, port=port, forwarded_allow_ips="*") + + +@app.command() +def dev( + host: str = "0.0.0.0", + port: int = 8080, + reload: bool = True, +): + uvicorn.run( + "main:app", host=host, port=port, reload=reload, forwarded_allow_ips="*" + ) + + +if __name__ == "__main__": + app() diff --git a/backend/requirements.txt b/backend/requirements.txt index a82da1966..7a3668428 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,42 +1,40 @@ -fastapi==0.109.2 +fastapi==0.111.0 uvicorn[standard]==0.22.0 pydantic==2.7.1 python-multipart==0.0.9 Flask==3.0.3 -Flask-Cors==4.0.0 +Flask-Cors==4.0.1 python-socketio==5.11.2 python-jose==3.3.0 passlib[bcrypt]==1.7.4 -requests==2.31.0 +requests==2.32.2 aiohttp==3.9.5 -peewee==3.17.3 +peewee==3.17.5 peewee-migrate==1.12.2 psycopg2-binary==2.9.9 -PyMySQL==1.1.0 -bcrypt==4.1.2 +PyMySQL==1.1.1 +bcrypt==4.1.3 -litellm[proxy]==1.35.28 - -boto3==1.34.95 +boto3==1.34.110 argon2-cffi==23.1.0 APScheduler==3.10.4 -google-generativeai==0.5.2 +google-generativeai==0.5.4 -langchain==0.1.16 -langchain-community==0.0.34 -langchain-chroma==0.1.0 +langchain==0.2.0 +langchain-community==0.2.0 +langchain-chroma==0.1.1 fake-useragent==1.5.1 
-chromadb==0.4.24 +chromadb==0.5.0 sentence-transformers==2.7.0 pypdf==4.2.0 docx2txt==0.8 python-pptx==0.6.23 -unstructured==0.11.8 +unstructured==0.14.0 Markdown==3.6 pypandoc==1.13 pandas==2.2.2 @@ -46,16 +44,16 @@ xlrd==2.0.1 validators==0.28.1 opencv-python-headless==4.9.0.80 -rapidocr-onnxruntime==1.2.3 +rapidocr-onnxruntime==1.3.22 -fpdf2==2.7.8 +fpdf2==2.7.9 rank-bm25==0.2.2 -faster-whisper==1.0.1 +faster-whisper==1.0.2 PyJWT[crypto]==2.8.0 black==24.4.2 -langfuse==2.27.3 +langfuse==2.33.0 youtube-transcript-api==0.6.2 -pytube \ No newline at end of file +pytube==15.0.0 \ No newline at end of file diff --git a/backend/start.sh b/backend/start.sh index 9b3411f01..15fc568d3 100755 --- a/backend/start.sh +++ b/backend/start.sh @@ -30,4 +30,29 @@ if [ "$USE_CUDA_DOCKER" = "true" ]; then export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib" fi + +# Check if SPACE_ID is set, if so, configure for space +if [ -n "$SPACE_ID" ]; then + echo "Configuring for HuggingFace Space deployment" + if [ -n "$ADMIN_USER_EMAIL" ] && [ -n "$ADMIN_USER_PASSWORD" ]; then + echo "Admin user configured, creating" + WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" uvicorn main:app --host "$HOST" --port "$PORT" --forwarded-allow-ips '*' & + webui_pid=$! + echo "Waiting for webui to start..." + while ! curl -s http://localhost:8080/health > /dev/null; do + sleep 1 + done + echo "Creating admin user..." + curl \ + -X POST "http://localhost:8080/api/v1/auths/signup" \ + -H "accept: application/json" \ + -H "Content-Type: application/json" \ + -d "{ \"email\": \"${ADMIN_USER_EMAIL}\", \"password\": \"${ADMIN_USER_PASSWORD}\", \"name\": \"Admin\" }" + echo "Shutting down webui..." 
+ kill $webui_pid + fi + + export WEBUI_URL=${SPACE_HOST} +fi + WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn main:app --host "$HOST" --port "$PORT" --forwarded-allow-ips '*' diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 5efff4a35..fca941263 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -1,5 +1,6 @@ from pathlib import Path import hashlib +import json import re from datetime import timedelta from typing import Optional @@ -110,3 +111,76 @@ def parse_duration(duration: str) -> Optional[timedelta]: total_duration += timedelta(weeks=number) return total_duration + + +def parse_ollama_modelfile(model_text): + parameters_meta = { + "mirostat": int, + "mirostat_eta": float, + "mirostat_tau": float, + "num_ctx": int, + "repeat_last_n": int, + "repeat_penalty": float, + "temperature": float, + "seed": int, + "stop": str, + "tfs_z": float, + "num_predict": int, + "top_k": int, + "top_p": float, + } + + data = {"base_model_id": None, "params": {}} + + # Parse base model + base_model_match = re.search( + r"^FROM\s+(\w+)", model_text, re.MULTILINE | re.IGNORECASE + ) + if base_model_match: + data["base_model_id"] = base_model_match.group(1) + + # Parse template + template_match = re.search( + r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE + ) + if template_match: + data["params"] = {"template": template_match.group(1).strip()} + + # Parse stops + stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE) + if stops: + data["params"]["stop"] = stops + + # Parse other parameters from the provided list + for param, param_type in parameters_meta.items(): + param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE) + if param_match: + value = param_match.group(1) + if param_type == int: + value = int(value) + elif param_type == float: + value = float(value) + data["params"][param] = value + + # Parse adapter + adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE) + if 
adapter_match: + data["params"]["adapter"] = adapter_match.group(1) + + # Parse system description + system_desc_match = re.search( + r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE + ) + if system_desc_match: + data["params"]["system"] = system_desc_match.group(1).strip() + + # Parse messages + messages = [] + message_matches = re.findall(r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE) + for role, content in message_matches: + messages.append({"role": role, "content": content}) + + if messages: + data["params"]["messages"] = messages + + return data diff --git a/backend/utils/models.py b/backend/utils/models.py new file mode 100644 index 000000000..c4d675d29 --- /dev/null +++ b/backend/utils/models.py @@ -0,0 +1,10 @@ +from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse + + +def get_model_id_from_custom_model_id(id: str): + model = Models.get_model_by_id(id) + + if model: + return model.id + else: + return id diff --git a/backend/utils/utils.py b/backend/utils/utils.py index af4fd85c0..cc6bb06b8 100644 --- a/backend/utils/utils.py +++ b/backend/utils/utils.py @@ -1,7 +1,7 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi import HTTPException, status, Depends -from apps.web.models.users import Users +from apps.webui.models.users import Users from pydantic import BaseModel from typing import Union, Optional diff --git a/cypress/e2e/chat.cy.ts b/cypress/e2e/chat.cy.ts index ced998104..ddb33d6c0 100644 --- a/cypress/e2e/chat.cy.ts +++ b/cypress/e2e/chat.cy.ts @@ -74,5 +74,28 @@ describe('Settings', () => { expect(spy).to.be.callCount(2); }); }); + + it('user can generate image', () => { + // Click on the model selector + cy.get('button[aria-label="Select a model"]').click(); + // Select the first model + cy.get('button[aria-label="model-item"]').first().click(); + // Type a message + cy.get('#chat-textarea').type('Hi, what can you do? 
A single sentence only please.', { + force: true + }); + // Send the message + cy.get('button[type="submit"]').click(); + // User's message should be visible + cy.get('.chat-user').should('exist'); + // Wait for the response + cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received + .find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received + .should('exist'); + // Click on the generate image button + cy.get('[aria-label="Generate Image"]').click(); + // Wait for image to be visible + cy.get('img[data-cy="image"]', { timeout: 60_000 }).should('be.visible'); + }); }); }); diff --git a/docker-compose.a1111-test.yaml b/docker-compose.a1111-test.yaml new file mode 100644 index 000000000..e6ab12c07 --- /dev/null +++ b/docker-compose.a1111-test.yaml @@ -0,0 +1,31 @@ +# This is an overlay that spins up stable-diffusion-webui for integration testing +# This is not designed to be used in production +services: + stable-diffusion-webui: + # Not built for ARM64 + platform: linux/amd64 + image: ghcr.io/neggles/sd-webui-docker:latest + restart: unless-stopped + environment: + CLI_ARGS: "--api --use-cpu all --precision full --no-half --skip-torch-cuda-test --ckpt /empty.pt --do-not-download-clip --disable-nan-check --disable-opt-split-attention" + PYTHONUNBUFFERED: "1" + TERM: "vt100" + SD_WEBUI_VARIANT: "default" + # Hack to get container working on Apple Silicon + # Rosetta creates a conflict ${HOME}/.cache folder + entrypoint: /bin/bash + command: + - -c + - | + export HOME=/root-home + rm -rf $${HOME}/.cache + /docker/entrypoint.sh python -u webui.py --listen --port $${WEBUI_PORT} --skip-version-check $${CLI_ARGS} + volumes: + - ./test/test_files/image_gen/sd-empty.pt:/empty.pt + + open-webui: + environment: + ENABLE_IMAGE_GENERATION: "true" + AUTOMATIC1111_BASE_URL: http://stable-diffusion-webui:7860 + IMAGE_SIZE: "64x64" + IMAGE_STEPS: "3" diff --git 
a/hatch_build.py b/hatch_build.py new file mode 100644 index 000000000..8ddaf0749 --- /dev/null +++ b/hatch_build.py @@ -0,0 +1,23 @@ +# noqa: INP001 +import os +import shutil +import subprocess +from sys import stderr + +from hatchling.builders.hooks.plugin.interface import BuildHookInterface + + +class CustomBuildHook(BuildHookInterface): + def initialize(self, version, build_data): + super().initialize(version, build_data) + stderr.write(">>> Building Open Webui frontend\n") + npm = shutil.which("npm") + if npm is None: + raise RuntimeError( + "NodeJS `npm` is required for building Open Webui but it was not found" + ) + stderr.write("### npm install\n") + subprocess.run([npm, "install"], check=True) # noqa: S603 + stderr.write("\n### npm run build\n") + os.environ["APP_BUILD_HASH"] = version + subprocess.run([npm, "run", "build"], check=True) # noqa: S603 diff --git a/package-lock.json b/package-lock.json index 5f98d38f6..5b1461939 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.1.125", + "version": "0.2.0.dev2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.1.125", + "version": "0.2.0.dev2", "dependencies": { "@pyscript/core": "^0.4.32", "@sveltejs/adapter-node": "^1.3.1", diff --git a/package.json b/package.json index 2b412e310..8522cffe5 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.1.125", + "version": "0.2.0.dev2", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", @@ -13,7 +13,7 @@ "lint:types": "npm run check", "lint:backend": "pylint backend/", "format": "prettier --plugin-search-dir --write \"**/*.{js,ts,svelte,css,md,html,json}\"", - "format:backend": "black . --exclude \"/venv/\"", + "format:backend": "black . 
--exclude \".venv/|/venv/\"", "i18n:parse": "i18next --config i18next-parser.config.ts && prettier --write \"src/lib/i18n/**/*.{js,json}\"", "cy:open": "cypress open", "test:frontend": "vitest", diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..004ce374b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,115 @@ +[project] +name = "open-webui" +description = "Open WebUI (Formerly Ollama WebUI)" +authors = [ + { name = "Timothy Jaeryang Baek", email = "tim@openwebui.com" } +] +license = { file = "LICENSE" } +dependencies = [ + "fastapi==0.111.0", + "uvicorn[standard]==0.22.0", + "pydantic==2.7.1", + "python-multipart==0.0.9", + + "Flask==3.0.3", + "Flask-Cors==4.0.1", + + "python-socketio==5.11.2", + "python-jose==3.3.0", + "passlib[bcrypt]==1.7.4", + + "requests==2.32.2", + "aiohttp==3.9.5", + "peewee==3.17.5", + "peewee-migrate==1.12.2", + "psycopg2-binary==2.9.9", + "PyMySQL==1.1.0", + "bcrypt==4.1.3", + + "litellm[proxy]==1.37.20", + + "boto3==1.34.110", + + "argon2-cffi==23.1.0", + "APScheduler==3.10.4", + "google-generativeai==0.5.4", + + "langchain==0.2.0", + "langchain-community==0.2.0", + "langchain-chroma==0.1.1", + + "fake-useragent==1.5.1", + "chromadb==0.5.0", + "sentence-transformers==2.7.0", + "pypdf==4.2.0", + "docx2txt==0.8", + "unstructured==0.14.0", + "Markdown==3.6", + "pypandoc==1.13", + "pandas==2.2.2", + "openpyxl==3.1.2", + "pyxlsb==1.0.10", + "xlrd==2.0.1", + "validators==0.28.1", + + "opencv-python-headless==4.9.0.80", + "rapidocr-onnxruntime==1.3.22", + + "fpdf2==2.7.9", + "rank-bm25==0.2.2", + + "faster-whisper==1.0.2", + + "PyJWT[crypto]==2.8.0", + + "black==24.4.2", + "langfuse==2.33.0", + "youtube-transcript-api==0.6.2", + "pytube==15.0.0", +] +readme = "README.md" +requires-python = ">= 3.11, < 3.12.0a1" +dynamic = ["version"] +classifiers = [ + "Development Status :: 4 - Beta", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + 
"Topic :: Communications :: Chat", + "Topic :: Multimedia", +] + +[project.scripts] +open-webui = "open_webui:app" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.rye] +managed = true +dev-dependencies = [] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.version] +path = "package.json" +pattern = '"version":\s*"(?P<version>[^"]+)"' + +[tool.hatch.build.hooks.custom] # keep this for reading hooks from `hatch_build.py` + +[tool.hatch.build.targets.wheel] +sources = ["backend"] +exclude = [ + ".dockerignore", + ".gitignore", + ".webui_secret_key", + "dev.sh", + "requirements.txt", + "start.sh", + "start_windows.bat", + "webui.db", + "chroma.sqlite3", +] +force-include = { "CHANGELOG.md" = "open_webui/CHANGELOG.md", build = "open_webui/frontend" } diff --git a/requirements-dev.lock b/requirements-dev.lock new file mode 100644 index 000000000..39b1d0ef0 --- /dev/null +++ b/requirements-dev.lock @@ -0,0 +1,688 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false +# generate-hashes: false + +-e file:. 
+aiohttp==3.9.5 + # via langchain + # via langchain-community + # via litellm + # via open-webui +aiosignal==1.3.1 + # via aiohttp +annotated-types==0.6.0 + # via pydantic +anyio==4.3.0 + # via httpx + # via openai + # via starlette + # via watchfiles +apscheduler==3.10.4 + # via litellm + # via open-webui +argon2-cffi==23.1.0 + # via open-webui +argon2-cffi-bindings==21.2.0 + # via argon2-cffi +asgiref==3.8.1 + # via opentelemetry-instrumentation-asgi +attrs==23.2.0 + # via aiohttp +av==11.0.0 + # via faster-whisper +backoff==2.2.1 + # via langfuse + # via litellm + # via posthog + # via unstructured +bcrypt==4.1.3 + # via chromadb + # via open-webui + # via passlib +beautifulsoup4==4.12.3 + # via unstructured +bidict==0.23.1 + # via python-socketio +black==24.4.2 + # via open-webui +blinker==1.8.2 + # via flask +boto3==1.34.110 + # via open-webui +botocore==1.34.110 + # via boto3 + # via s3transfer +build==1.2.1 + # via chromadb +cachetools==5.3.3 + # via google-auth +certifi==2024.2.2 + # via httpcore + # via httpx + # via kubernetes + # via requests + # via unstructured-client +cffi==1.16.0 + # via argon2-cffi-bindings + # via cryptography +chardet==5.2.0 + # via unstructured +charset-normalizer==3.3.2 + # via requests + # via unstructured-client +chroma-hnswlib==0.7.3 + # via chromadb +chromadb==0.5.0 + # via langchain-chroma + # via open-webui +click==8.1.7 + # via black + # via flask + # via litellm + # via nltk + # via peewee-migrate + # via rq + # via typer + # via uvicorn +coloredlogs==15.0.1 + # via onnxruntime +cryptography==42.0.7 + # via litellm + # via pyjwt +ctranslate2==4.2.1 + # via faster-whisper +dataclasses-json==0.6.6 + # via langchain + # via langchain-community + # via unstructured + # via unstructured-client +deepdiff==7.0.1 + # via unstructured-client +defusedxml==0.7.1 + # via fpdf2 +deprecated==1.2.14 + # via opentelemetry-api + # via opentelemetry-exporter-otlp-proto-grpc +distro==1.9.0 + # via openai +dnspython==2.6.1 + # via 
email-validator +docx2txt==0.8 + # via open-webui +ecdsa==0.19.0 + # via python-jose +email-validator==2.1.1 + # via fastapi + # via pydantic +emoji==2.11.1 + # via unstructured +et-xmlfile==1.1.0 + # via openpyxl +fake-useragent==1.5.1 + # via open-webui +fastapi==0.111.0 + # via chromadb + # via fastapi-sso + # via langchain-chroma + # via litellm + # via open-webui +fastapi-cli==0.0.4 + # via fastapi +fastapi-sso==0.10.0 + # via litellm +faster-whisper==1.0.2 + # via open-webui +filelock==3.14.0 + # via huggingface-hub + # via torch + # via transformers +filetype==1.2.0 + # via unstructured +flask==3.0.3 + # via flask-cors + # via open-webui +flask-cors==4.0.1 + # via open-webui +flatbuffers==24.3.25 + # via onnxruntime +fonttools==4.51.0 + # via fpdf2 +fpdf2==2.7.9 + # via open-webui +frozenlist==1.4.1 + # via aiohttp + # via aiosignal +fsspec==2024.3.1 + # via huggingface-hub + # via torch +google-ai-generativelanguage==0.6.4 + # via google-generativeai +google-api-core==2.19.0 + # via google-ai-generativelanguage + # via google-api-python-client + # via google-generativeai +google-api-python-client==2.129.0 + # via google-generativeai +google-auth==2.29.0 + # via google-ai-generativelanguage + # via google-api-core + # via google-api-python-client + # via google-auth-httplib2 + # via google-generativeai + # via kubernetes +google-auth-httplib2==0.2.0 + # via google-api-python-client +google-generativeai==0.5.4 + # via open-webui +googleapis-common-protos==1.63.0 + # via google-api-core + # via grpcio-status + # via opentelemetry-exporter-otlp-proto-grpc +grpcio==1.63.0 + # via chromadb + # via google-api-core + # via grpcio-status + # via opentelemetry-exporter-otlp-proto-grpc +grpcio-status==1.62.2 + # via google-api-core +gunicorn==22.0.0 + # via litellm +h11==0.14.0 + # via httpcore + # via uvicorn + # via wsproto +httpcore==1.0.5 + # via httpx +httplib2==0.22.0 + # via google-api-python-client + # via google-auth-httplib2 +httptools==0.6.1 + # via uvicorn 
+httpx==0.27.0 + # via fastapi + # via fastapi-sso + # via langfuse + # via openai +huggingface-hub==0.23.0 + # via faster-whisper + # via sentence-transformers + # via tokenizers + # via transformers +humanfriendly==10.0 + # via coloredlogs +idna==3.7 + # via anyio + # via email-validator + # via httpx + # via langfuse + # via requests + # via unstructured-client + # via yarl +importlib-metadata==7.0.0 + # via litellm + # via opentelemetry-api +importlib-resources==6.4.0 + # via chromadb +itsdangerous==2.2.0 + # via flask +jinja2==3.1.4 + # via fastapi + # via flask + # via litellm + # via torch +jmespath==1.0.1 + # via boto3 + # via botocore +joblib==1.4.2 + # via nltk + # via scikit-learn +jsonpatch==1.33 + # via langchain-core +jsonpath-python==1.0.6 + # via unstructured-client +jsonpointer==2.4 + # via jsonpatch +kubernetes==29.0.0 + # via chromadb +langchain==0.2.0 + # via langchain-community + # via open-webui +langchain-chroma==0.1.1 + # via open-webui +langchain-community==0.2.0 + # via open-webui +langchain-core==0.2.1 + # via langchain + # via langchain-chroma + # via langchain-community + # via langchain-text-splitters +langchain-text-splitters==0.2.0 + # via langchain +langdetect==1.0.9 + # via unstructured +langfuse==2.33.0 + # via open-webui +langsmith==0.1.57 + # via langchain + # via langchain-community + # via langchain-core +litellm==1.37.20 + # via open-webui +lxml==5.2.2 + # via unstructured +markdown==3.6 + # via open-webui +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.5 + # via jinja2 + # via werkzeug +marshmallow==3.21.2 + # via dataclasses-json + # via unstructured-client +mdurl==0.1.2 + # via markdown-it-py +mmh3==4.1.0 + # via chromadb +monotonic==1.6 + # via posthog +mpmath==1.3.0 + # via sympy +multidict==6.0.5 + # via aiohttp + # via yarl +mypy-extensions==1.0.0 + # via black + # via typing-inspect + # via unstructured-client +networkx==3.3 + # via torch +nltk==3.8.1 + # via unstructured +numpy==1.26.4 + # via chroma-hnswlib + # 
via chromadb + # via ctranslate2 + # via langchain + # via langchain-chroma + # via langchain-community + # via onnxruntime + # via opencv-python + # via opencv-python-headless + # via pandas + # via rank-bm25 + # via rapidocr-onnxruntime + # via scikit-learn + # via scipy + # via sentence-transformers + # via shapely + # via transformers + # via unstructured +oauthlib==3.2.2 + # via fastapi-sso + # via kubernetes + # via requests-oauthlib +onnxruntime==1.17.3 + # via chromadb + # via faster-whisper + # via rapidocr-onnxruntime +openai==1.28.1 + # via litellm +opencv-python==4.9.0.80 + # via rapidocr-onnxruntime +opencv-python-headless==4.9.0.80 + # via open-webui +openpyxl==3.1.2 + # via open-webui +opentelemetry-api==1.24.0 + # via chromadb + # via opentelemetry-exporter-otlp-proto-grpc + # via opentelemetry-instrumentation + # via opentelemetry-instrumentation-asgi + # via opentelemetry-instrumentation-fastapi + # via opentelemetry-sdk +opentelemetry-exporter-otlp-proto-common==1.24.0 + # via opentelemetry-exporter-otlp-proto-grpc +opentelemetry-exporter-otlp-proto-grpc==1.24.0 + # via chromadb +opentelemetry-instrumentation==0.45b0 + # via opentelemetry-instrumentation-asgi + # via opentelemetry-instrumentation-fastapi +opentelemetry-instrumentation-asgi==0.45b0 + # via opentelemetry-instrumentation-fastapi +opentelemetry-instrumentation-fastapi==0.45b0 + # via chromadb +opentelemetry-proto==1.24.0 + # via opentelemetry-exporter-otlp-proto-common + # via opentelemetry-exporter-otlp-proto-grpc +opentelemetry-sdk==1.24.0 + # via chromadb + # via opentelemetry-exporter-otlp-proto-grpc +opentelemetry-semantic-conventions==0.45b0 + # via opentelemetry-instrumentation-asgi + # via opentelemetry-instrumentation-fastapi + # via opentelemetry-sdk +opentelemetry-util-http==0.45b0 + # via opentelemetry-instrumentation-asgi + # via opentelemetry-instrumentation-fastapi +ordered-set==4.1.0 + # via deepdiff +orjson==3.10.3 + # via chromadb + # via fastapi + # via langsmith + 
# via litellm +overrides==7.7.0 + # via chromadb +packaging==23.2 + # via black + # via build + # via gunicorn + # via huggingface-hub + # via langchain-core + # via langfuse + # via marshmallow + # via onnxruntime + # via transformers + # via unstructured-client +pandas==2.2.2 + # via open-webui +passlib==1.7.4 + # via open-webui +pathspec==0.12.1 + # via black +peewee==3.17.5 + # via open-webui + # via peewee-migrate +peewee-migrate==1.12.2 + # via open-webui +pillow==10.3.0 + # via fpdf2 + # via rapidocr-onnxruntime + # via sentence-transformers +platformdirs==4.2.1 + # via black +posthog==3.5.0 + # via chromadb +proto-plus==1.23.0 + # via google-ai-generativelanguage + # via google-api-core +protobuf==4.25.3 + # via google-ai-generativelanguage + # via google-api-core + # via google-generativeai + # via googleapis-common-protos + # via grpcio-status + # via onnxruntime + # via opentelemetry-proto + # via proto-plus +psycopg2-binary==2.9.9 + # via open-webui +pyasn1==0.6.0 + # via pyasn1-modules + # via python-jose + # via rsa +pyasn1-modules==0.4.0 + # via google-auth +pyclipper==1.3.0.post5 + # via rapidocr-onnxruntime +pycparser==2.22 + # via cffi +pydantic==2.7.1 + # via chromadb + # via fastapi + # via fastapi-sso + # via google-generativeai + # via langchain + # via langchain-core + # via langfuse + # via langsmith + # via open-webui + # via openai +pydantic-core==2.18.2 + # via pydantic +pygments==2.18.0 + # via rich +pyjwt==2.8.0 + # via litellm + # via open-webui +pymysql==1.1.0 + # via open-webui +pypandoc==1.13 + # via open-webui +pyparsing==3.1.2 + # via httplib2 +pypdf==4.2.0 + # via open-webui + # via unstructured-client +pypika==0.48.9 + # via chromadb +pyproject-hooks==1.1.0 + # via build +python-dateutil==2.9.0.post0 + # via botocore + # via kubernetes + # via pandas + # via posthog + # via unstructured-client +python-dotenv==1.0.1 + # via litellm + # via uvicorn +python-engineio==4.9.0 + # via python-socketio +python-iso639==2024.4.27 + # via 
unstructured +python-jose==3.3.0 + # via open-webui +python-magic==0.4.27 + # via unstructured +python-multipart==0.0.9 + # via fastapi + # via litellm + # via open-webui +python-socketio==5.11.2 + # via open-webui +pytube==15.0.0 + # via open-webui +pytz==2024.1 + # via apscheduler + # via pandas +pyxlsb==1.0.10 + # via open-webui +pyyaml==6.0.1 + # via chromadb + # via ctranslate2 + # via huggingface-hub + # via kubernetes + # via langchain + # via langchain-community + # via langchain-core + # via litellm + # via rapidocr-onnxruntime + # via transformers + # via uvicorn +rank-bm25==0.2.2 + # via open-webui +rapidfuzz==3.9.0 + # via unstructured +rapidocr-onnxruntime==1.3.22 + # via open-webui +redis==5.0.4 + # via rq +regex==2024.5.10 + # via nltk + # via tiktoken + # via transformers +requests==2.32.2 + # via chromadb + # via google-api-core + # via huggingface-hub + # via kubernetes + # via langchain + # via langchain-community + # via langsmith + # via litellm + # via open-webui + # via posthog + # via requests-oauthlib + # via tiktoken + # via transformers + # via unstructured + # via unstructured-client + # via youtube-transcript-api +requests-oauthlib==2.0.0 + # via kubernetes +rich==13.7.1 + # via typer +rq==1.16.2 + # via litellm +rsa==4.9 + # via google-auth + # via python-jose +s3transfer==0.10.1 + # via boto3 +safetensors==0.4.3 + # via transformers +scikit-learn==1.4.2 + # via sentence-transformers +scipy==1.13.0 + # via scikit-learn + # via sentence-transformers +sentence-transformers==2.7.0 + # via open-webui +setuptools==69.5.1 + # via ctranslate2 + # via opentelemetry-instrumentation +shapely==2.0.4 + # via rapidocr-onnxruntime +shellingham==1.5.4 + # via typer +simple-websocket==1.0.0 + # via python-engineio +six==1.16.0 + # via apscheduler + # via ecdsa + # via kubernetes + # via langdetect + # via posthog + # via python-dateutil + # via rapidocr-onnxruntime + # via unstructured-client +sniffio==1.3.1 + # via anyio + # via httpx + # via openai 
+soupsieve==2.5 + # via beautifulsoup4 +sqlalchemy==2.0.30 + # via langchain + # via langchain-community +starlette==0.37.2 + # via fastapi +sympy==1.12 + # via onnxruntime + # via torch +tabulate==0.9.0 + # via unstructured +tenacity==8.3.0 + # via chromadb + # via langchain + # via langchain-community + # via langchain-core +threadpoolctl==3.5.0 + # via scikit-learn +tiktoken==0.6.0 + # via litellm +tokenizers==0.15.2 + # via chromadb + # via faster-whisper + # via litellm + # via transformers +torch==2.3.0 + # via sentence-transformers +tqdm==4.66.4 + # via chromadb + # via google-generativeai + # via huggingface-hub + # via nltk + # via openai + # via sentence-transformers + # via transformers +transformers==4.39.3 + # via sentence-transformers +typer==0.12.3 + # via chromadb + # via fastapi-cli +typing-extensions==4.11.0 + # via chromadb + # via fastapi + # via google-generativeai + # via huggingface-hub + # via openai + # via opentelemetry-sdk + # via pydantic + # via pydantic-core + # via sqlalchemy + # via torch + # via typer + # via typing-inspect + # via unstructured + # via unstructured-client +typing-inspect==0.9.0 + # via dataclasses-json + # via unstructured-client +tzdata==2024.1 + # via pandas +tzlocal==5.2 + # via apscheduler +ujson==5.10.0 + # via fastapi +unstructured==0.14.0 + # via open-webui +unstructured-client==0.22.0 + # via unstructured +uritemplate==4.1.1 + # via google-api-python-client +urllib3==2.2.1 + # via botocore + # via kubernetes + # via requests + # via unstructured-client +uvicorn==0.22.0 + # via chromadb + # via fastapi + # via litellm + # via open-webui +uvloop==0.19.0 + # via uvicorn +validators==0.28.1 + # via open-webui +watchfiles==0.21.0 + # via uvicorn +websocket-client==1.8.0 + # via kubernetes +websockets==12.0 + # via uvicorn +werkzeug==3.0.3 + # via flask +wrapt==1.16.0 + # via deprecated + # via langfuse + # via opentelemetry-instrumentation + # via unstructured +wsproto==1.2.0 + # via simple-websocket +xlrd==2.0.1 
+ # via open-webui +yarl==1.9.4 + # via aiohttp +youtube-transcript-api==0.6.2 + # via open-webui +zipp==3.18.1 + # via importlib-metadata diff --git a/requirements.lock b/requirements.lock new file mode 100644 index 000000000..39b1d0ef0 --- /dev/null +++ b/requirements.lock @@ -0,0 +1,688 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false +# generate-hashes: false + +-e file:. +aiohttp==3.9.5 + # via langchain + # via langchain-community + # via litellm + # via open-webui +aiosignal==1.3.1 + # via aiohttp +annotated-types==0.6.0 + # via pydantic +anyio==4.3.0 + # via httpx + # via openai + # via starlette + # via watchfiles +apscheduler==3.10.4 + # via litellm + # via open-webui +argon2-cffi==23.1.0 + # via open-webui +argon2-cffi-bindings==21.2.0 + # via argon2-cffi +asgiref==3.8.1 + # via opentelemetry-instrumentation-asgi +attrs==23.2.0 + # via aiohttp +av==11.0.0 + # via faster-whisper +backoff==2.2.1 + # via langfuse + # via litellm + # via posthog + # via unstructured +bcrypt==4.1.3 + # via chromadb + # via open-webui + # via passlib +beautifulsoup4==4.12.3 + # via unstructured +bidict==0.23.1 + # via python-socketio +black==24.4.2 + # via open-webui +blinker==1.8.2 + # via flask +boto3==1.34.110 + # via open-webui +botocore==1.34.110 + # via boto3 + # via s3transfer +build==1.2.1 + # via chromadb +cachetools==5.3.3 + # via google-auth +certifi==2024.2.2 + # via httpcore + # via httpx + # via kubernetes + # via requests + # via unstructured-client +cffi==1.16.0 + # via argon2-cffi-bindings + # via cryptography +chardet==5.2.0 + # via unstructured +charset-normalizer==3.3.2 + # via requests + # via unstructured-client +chroma-hnswlib==0.7.3 + # via chromadb +chromadb==0.5.0 + # via langchain-chroma + # via open-webui +click==8.1.7 + # via black + # via flask + # via litellm + # via nltk + # via peewee-migrate + # 
via rq + # via typer + # via uvicorn +coloredlogs==15.0.1 + # via onnxruntime +cryptography==42.0.7 + # via litellm + # via pyjwt +ctranslate2==4.2.1 + # via faster-whisper +dataclasses-json==0.6.6 + # via langchain + # via langchain-community + # via unstructured + # via unstructured-client +deepdiff==7.0.1 + # via unstructured-client +defusedxml==0.7.1 + # via fpdf2 +deprecated==1.2.14 + # via opentelemetry-api + # via opentelemetry-exporter-otlp-proto-grpc +distro==1.9.0 + # via openai +dnspython==2.6.1 + # via email-validator +docx2txt==0.8 + # via open-webui +ecdsa==0.19.0 + # via python-jose +email-validator==2.1.1 + # via fastapi + # via pydantic +emoji==2.11.1 + # via unstructured +et-xmlfile==1.1.0 + # via openpyxl +fake-useragent==1.5.1 + # via open-webui +fastapi==0.111.0 + # via chromadb + # via fastapi-sso + # via langchain-chroma + # via litellm + # via open-webui +fastapi-cli==0.0.4 + # via fastapi +fastapi-sso==0.10.0 + # via litellm +faster-whisper==1.0.2 + # via open-webui +filelock==3.14.0 + # via huggingface-hub + # via torch + # via transformers +filetype==1.2.0 + # via unstructured +flask==3.0.3 + # via flask-cors + # via open-webui +flask-cors==4.0.1 + # via open-webui +flatbuffers==24.3.25 + # via onnxruntime +fonttools==4.51.0 + # via fpdf2 +fpdf2==2.7.9 + # via open-webui +frozenlist==1.4.1 + # via aiohttp + # via aiosignal +fsspec==2024.3.1 + # via huggingface-hub + # via torch +google-ai-generativelanguage==0.6.4 + # via google-generativeai +google-api-core==2.19.0 + # via google-ai-generativelanguage + # via google-api-python-client + # via google-generativeai +google-api-python-client==2.129.0 + # via google-generativeai +google-auth==2.29.0 + # via google-ai-generativelanguage + # via google-api-core + # via google-api-python-client + # via google-auth-httplib2 + # via google-generativeai + # via kubernetes +google-auth-httplib2==0.2.0 + # via google-api-python-client +google-generativeai==0.5.4 + # via open-webui 
+googleapis-common-protos==1.63.0 + # via google-api-core + # via grpcio-status + # via opentelemetry-exporter-otlp-proto-grpc +grpcio==1.63.0 + # via chromadb + # via google-api-core + # via grpcio-status + # via opentelemetry-exporter-otlp-proto-grpc +grpcio-status==1.62.2 + # via google-api-core +gunicorn==22.0.0 + # via litellm +h11==0.14.0 + # via httpcore + # via uvicorn + # via wsproto +httpcore==1.0.5 + # via httpx +httplib2==0.22.0 + # via google-api-python-client + # via google-auth-httplib2 +httptools==0.6.1 + # via uvicorn +httpx==0.27.0 + # via fastapi + # via fastapi-sso + # via langfuse + # via openai +huggingface-hub==0.23.0 + # via faster-whisper + # via sentence-transformers + # via tokenizers + # via transformers +humanfriendly==10.0 + # via coloredlogs +idna==3.7 + # via anyio + # via email-validator + # via httpx + # via langfuse + # via requests + # via unstructured-client + # via yarl +importlib-metadata==7.0.0 + # via litellm + # via opentelemetry-api +importlib-resources==6.4.0 + # via chromadb +itsdangerous==2.2.0 + # via flask +jinja2==3.1.4 + # via fastapi + # via flask + # via litellm + # via torch +jmespath==1.0.1 + # via boto3 + # via botocore +joblib==1.4.2 + # via nltk + # via scikit-learn +jsonpatch==1.33 + # via langchain-core +jsonpath-python==1.0.6 + # via unstructured-client +jsonpointer==2.4 + # via jsonpatch +kubernetes==29.0.0 + # via chromadb +langchain==0.2.0 + # via langchain-community + # via open-webui +langchain-chroma==0.1.1 + # via open-webui +langchain-community==0.2.0 + # via open-webui +langchain-core==0.2.1 + # via langchain + # via langchain-chroma + # via langchain-community + # via langchain-text-splitters +langchain-text-splitters==0.2.0 + # via langchain +langdetect==1.0.9 + # via unstructured +langfuse==2.33.0 + # via open-webui +langsmith==0.1.57 + # via langchain + # via langchain-community + # via langchain-core +litellm==1.37.20 + # via open-webui +lxml==5.2.2 + # via unstructured +markdown==3.6 + # via 
open-webui +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.5 + # via jinja2 + # via werkzeug +marshmallow==3.21.2 + # via dataclasses-json + # via unstructured-client +mdurl==0.1.2 + # via markdown-it-py +mmh3==4.1.0 + # via chromadb +monotonic==1.6 + # via posthog +mpmath==1.3.0 + # via sympy +multidict==6.0.5 + # via aiohttp + # via yarl +mypy-extensions==1.0.0 + # via black + # via typing-inspect + # via unstructured-client +networkx==3.3 + # via torch +nltk==3.8.1 + # via unstructured +numpy==1.26.4 + # via chroma-hnswlib + # via chromadb + # via ctranslate2 + # via langchain + # via langchain-chroma + # via langchain-community + # via onnxruntime + # via opencv-python + # via opencv-python-headless + # via pandas + # via rank-bm25 + # via rapidocr-onnxruntime + # via scikit-learn + # via scipy + # via sentence-transformers + # via shapely + # via transformers + # via unstructured +oauthlib==3.2.2 + # via fastapi-sso + # via kubernetes + # via requests-oauthlib +onnxruntime==1.17.3 + # via chromadb + # via faster-whisper + # via rapidocr-onnxruntime +openai==1.28.1 + # via litellm +opencv-python==4.9.0.80 + # via rapidocr-onnxruntime +opencv-python-headless==4.9.0.80 + # via open-webui +openpyxl==3.1.2 + # via open-webui +opentelemetry-api==1.24.0 + # via chromadb + # via opentelemetry-exporter-otlp-proto-grpc + # via opentelemetry-instrumentation + # via opentelemetry-instrumentation-asgi + # via opentelemetry-instrumentation-fastapi + # via opentelemetry-sdk +opentelemetry-exporter-otlp-proto-common==1.24.0 + # via opentelemetry-exporter-otlp-proto-grpc +opentelemetry-exporter-otlp-proto-grpc==1.24.0 + # via chromadb +opentelemetry-instrumentation==0.45b0 + # via opentelemetry-instrumentation-asgi + # via opentelemetry-instrumentation-fastapi +opentelemetry-instrumentation-asgi==0.45b0 + # via opentelemetry-instrumentation-fastapi +opentelemetry-instrumentation-fastapi==0.45b0 + # via chromadb +opentelemetry-proto==1.24.0 + # via 
opentelemetry-exporter-otlp-proto-common + # via opentelemetry-exporter-otlp-proto-grpc +opentelemetry-sdk==1.24.0 + # via chromadb + # via opentelemetry-exporter-otlp-proto-grpc +opentelemetry-semantic-conventions==0.45b0 + # via opentelemetry-instrumentation-asgi + # via opentelemetry-instrumentation-fastapi + # via opentelemetry-sdk +opentelemetry-util-http==0.45b0 + # via opentelemetry-instrumentation-asgi + # via opentelemetry-instrumentation-fastapi +ordered-set==4.1.0 + # via deepdiff +orjson==3.10.3 + # via chromadb + # via fastapi + # via langsmith + # via litellm +overrides==7.7.0 + # via chromadb +packaging==23.2 + # via black + # via build + # via gunicorn + # via huggingface-hub + # via langchain-core + # via langfuse + # via marshmallow + # via onnxruntime + # via transformers + # via unstructured-client +pandas==2.2.2 + # via open-webui +passlib==1.7.4 + # via open-webui +pathspec==0.12.1 + # via black +peewee==3.17.5 + # via open-webui + # via peewee-migrate +peewee-migrate==1.12.2 + # via open-webui +pillow==10.3.0 + # via fpdf2 + # via rapidocr-onnxruntime + # via sentence-transformers +platformdirs==4.2.1 + # via black +posthog==3.5.0 + # via chromadb +proto-plus==1.23.0 + # via google-ai-generativelanguage + # via google-api-core +protobuf==4.25.3 + # via google-ai-generativelanguage + # via google-api-core + # via google-generativeai + # via googleapis-common-protos + # via grpcio-status + # via onnxruntime + # via opentelemetry-proto + # via proto-plus +psycopg2-binary==2.9.9 + # via open-webui +pyasn1==0.6.0 + # via pyasn1-modules + # via python-jose + # via rsa +pyasn1-modules==0.4.0 + # via google-auth +pyclipper==1.3.0.post5 + # via rapidocr-onnxruntime +pycparser==2.22 + # via cffi +pydantic==2.7.1 + # via chromadb + # via fastapi + # via fastapi-sso + # via google-generativeai + # via langchain + # via langchain-core + # via langfuse + # via langsmith + # via open-webui + # via openai +pydantic-core==2.18.2 + # via pydantic 
+pygments==2.18.0 + # via rich +pyjwt==2.8.0 + # via litellm + # via open-webui +pymysql==1.1.0 + # via open-webui +pypandoc==1.13 + # via open-webui +pyparsing==3.1.2 + # via httplib2 +pypdf==4.2.0 + # via open-webui + # via unstructured-client +pypika==0.48.9 + # via chromadb +pyproject-hooks==1.1.0 + # via build +python-dateutil==2.9.0.post0 + # via botocore + # via kubernetes + # via pandas + # via posthog + # via unstructured-client +python-dotenv==1.0.1 + # via litellm + # via uvicorn +python-engineio==4.9.0 + # via python-socketio +python-iso639==2024.4.27 + # via unstructured +python-jose==3.3.0 + # via open-webui +python-magic==0.4.27 + # via unstructured +python-multipart==0.0.9 + # via fastapi + # via litellm + # via open-webui +python-socketio==5.11.2 + # via open-webui +pytube==15.0.0 + # via open-webui +pytz==2024.1 + # via apscheduler + # via pandas +pyxlsb==1.0.10 + # via open-webui +pyyaml==6.0.1 + # via chromadb + # via ctranslate2 + # via huggingface-hub + # via kubernetes + # via langchain + # via langchain-community + # via langchain-core + # via litellm + # via rapidocr-onnxruntime + # via transformers + # via uvicorn +rank-bm25==0.2.2 + # via open-webui +rapidfuzz==3.9.0 + # via unstructured +rapidocr-onnxruntime==1.3.22 + # via open-webui +redis==5.0.4 + # via rq +regex==2024.5.10 + # via nltk + # via tiktoken + # via transformers +requests==2.32.2 + # via chromadb + # via google-api-core + # via huggingface-hub + # via kubernetes + # via langchain + # via langchain-community + # via langsmith + # via litellm + # via open-webui + # via posthog + # via requests-oauthlib + # via tiktoken + # via transformers + # via unstructured + # via unstructured-client + # via youtube-transcript-api +requests-oauthlib==2.0.0 + # via kubernetes +rich==13.7.1 + # via typer +rq==1.16.2 + # via litellm +rsa==4.9 + # via google-auth + # via python-jose +s3transfer==0.10.1 + # via boto3 +safetensors==0.4.3 + # via transformers +scikit-learn==1.4.2 + # via 
sentence-transformers +scipy==1.13.0 + # via scikit-learn + # via sentence-transformers +sentence-transformers==2.7.0 + # via open-webui +setuptools==69.5.1 + # via ctranslate2 + # via opentelemetry-instrumentation +shapely==2.0.4 + # via rapidocr-onnxruntime +shellingham==1.5.4 + # via typer +simple-websocket==1.0.0 + # via python-engineio +six==1.16.0 + # via apscheduler + # via ecdsa + # via kubernetes + # via langdetect + # via posthog + # via python-dateutil + # via rapidocr-onnxruntime + # via unstructured-client +sniffio==1.3.1 + # via anyio + # via httpx + # via openai +soupsieve==2.5 + # via beautifulsoup4 +sqlalchemy==2.0.30 + # via langchain + # via langchain-community +starlette==0.37.2 + # via fastapi +sympy==1.12 + # via onnxruntime + # via torch +tabulate==0.9.0 + # via unstructured +tenacity==8.3.0 + # via chromadb + # via langchain + # via langchain-community + # via langchain-core +threadpoolctl==3.5.0 + # via scikit-learn +tiktoken==0.6.0 + # via litellm +tokenizers==0.15.2 + # via chromadb + # via faster-whisper + # via litellm + # via transformers +torch==2.3.0 + # via sentence-transformers +tqdm==4.66.4 + # via chromadb + # via google-generativeai + # via huggingface-hub + # via nltk + # via openai + # via sentence-transformers + # via transformers +transformers==4.39.3 + # via sentence-transformers +typer==0.12.3 + # via chromadb + # via fastapi-cli +typing-extensions==4.11.0 + # via chromadb + # via fastapi + # via google-generativeai + # via huggingface-hub + # via openai + # via opentelemetry-sdk + # via pydantic + # via pydantic-core + # via sqlalchemy + # via torch + # via typer + # via typing-inspect + # via unstructured + # via unstructured-client +typing-inspect==0.9.0 + # via dataclasses-json + # via unstructured-client +tzdata==2024.1 + # via pandas +tzlocal==5.2 + # via apscheduler +ujson==5.10.0 + # via fastapi +unstructured==0.14.0 + # via open-webui +unstructured-client==0.22.0 + # via unstructured +uritemplate==4.1.1 + # via 
google-api-python-client +urllib3==2.2.1 + # via botocore + # via kubernetes + # via requests + # via unstructured-client +uvicorn==0.22.0 + # via chromadb + # via fastapi + # via litellm + # via open-webui +uvloop==0.19.0 + # via uvicorn +validators==0.28.1 + # via open-webui +watchfiles==0.21.0 + # via uvicorn +websocket-client==1.8.0 + # via kubernetes +websockets==12.0 + # via uvicorn +werkzeug==3.0.3 + # via flask +wrapt==1.16.0 + # via deprecated + # via langfuse + # via opentelemetry-instrumentation + # via unstructured +wsproto==1.2.0 + # via simple-websocket +xlrd==2.0.1 + # via open-webui +yarl==1.9.4 + # via aiohttp +youtube-transcript-api==0.6.2 + # via open-webui +zipp==3.18.1 + # via importlib-metadata diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index a72b51939..834e29d29 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -654,3 +654,35 @@ export const deleteAllChats = async (token: string) => { return res; }; + +export const archiveAllChats = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/archive/all`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts index 30d562ba4..4f53c53c8 100644 --- a/src/lib/apis/configs/index.ts +++ b/src/lib/apis/configs/index.ts @@ -1,4 +1,5 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; +import type { Banner } from '$lib/types'; export const setDefaultModels = async (token: string, models: string) => { let error = null; @@ -59,3 +60,60 @@ export const 
setDefaultPromptSuggestions = async (token: string, promptSuggestio return res; }; + +export const getBanners = async (token: string): Promise => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/banners`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const setBanners = async (token: string, banners: Banner[]) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/banners`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + banners: banners + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index a610f7210..dc51abd52 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -1,4 +1,53 @@ -import { WEBUI_BASE_URL } from '$lib/constants'; +import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; + +export const getModels = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/models`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + let models = res?.data ?? 
[]; + + models = models + .filter((models) => models) + .sort((a, b) => { + // Compare case-insensitively + const lowerA = a.name.toLowerCase(); + const lowerB = b.name.toLowerCase(); + + if (lowerA < lowerB) return -1; + if (lowerA > lowerB) return 1; + + // If same case-insensitively, sort by original strings, + // lowercase will come before uppercase due to ASCII values + if (a < b) return -1; + if (a > b) return 1; + + return 0; // They are equal + }); + + console.log(models); + return models; +}; export const getBackendConfig = async () => { let error = null; @@ -196,3 +245,131 @@ export const updateWebhookUrl = async (token: string, url: string) => { return res.url; }; + +export const getCommunitySharingEnabledStatus = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/community_sharing`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const toggleCommunitySharingEnabledStatus = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/community_sharing/toggle`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getModelConfig = async (token: string): Promise => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/models`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async 
(res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res.models; +}; + +export interface ModelConfig { + id: string; + name: string; + meta: ModelMeta; + base_model_id?: string; + params: ModelParams; +} + +export interface ModelMeta { + description?: string; + capabilities?: object; +} + +export interface ModelParams {} + +export type GlobalModelConfig = ModelConfig[]; + +export const updateModelConfig = async (token: string, config: GlobalModelConfig) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/models`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + models: config + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/litellm/index.ts b/src/lib/apis/litellm/index.ts deleted file mode 100644 index 643146b73..000000000 --- a/src/lib/apis/litellm/index.ts +++ /dev/null @@ -1,150 +0,0 @@ -import { LITELLM_API_BASE_URL } from '$lib/constants'; - -export const getLiteLLMModels = async (token: string = '') => { - let error = null; - - const res = await fetch(`${LITELLM_API_BASE_URL}/v1/models`, { - method: 'GET', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - error = `LiteLLM: ${err?.error?.message ?? 'Network Problem'}`; - return []; - }); - - if (error) { - throw error; - } - - const models = Array.isArray(res) ? res : res?.data ?? null; - - return models - ? 
models - .map((model) => ({ - id: model.id, - name: model.name ?? model.id, - external: true, - source: 'LiteLLM' - })) - .sort((a, b) => { - return a.name.localeCompare(b.name); - }) - : models; -}; - -export const getLiteLLMModelInfo = async (token: string = '') => { - let error = null; - - const res = await fetch(`${LITELLM_API_BASE_URL}/model/info`, { - method: 'GET', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - error = `LiteLLM: ${err?.error?.message ?? 'Network Problem'}`; - return []; - }); - - if (error) { - throw error; - } - - const models = Array.isArray(res) ? res : res?.data ?? null; - - return models; -}; - -type AddLiteLLMModelForm = { - name: string; - model: string; - api_base: string; - api_key: string; - rpm: string; - max_tokens: string; -}; - -export const addLiteLLMModel = async (token: string = '', payload: AddLiteLLMModelForm) => { - let error = null; - - const res = await fetch(`${LITELLM_API_BASE_URL}/model/new`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - }, - body: JSON.stringify({ - model_name: payload.name, - litellm_params: { - model: payload.model, - ...(payload.api_base === '' ? {} : { api_base: payload.api_base }), - ...(payload.api_key === '' ? {} : { api_key: payload.api_key }), - ...(isNaN(parseInt(payload.rpm)) ? {} : { rpm: parseInt(payload.rpm) }), - ...(payload.max_tokens === '' ? {} : { max_tokens: payload.max_tokens }) - } - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - error = `LiteLLM: ${err?.error?.message ?? 
'Network Problem'}`; - return []; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const deleteLiteLLMModel = async (token: string = '', id: string) => { - let error = null; - - const res = await fetch(`${LITELLM_API_BASE_URL}/model/delete`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - }, - body: JSON.stringify({ - id: id - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - error = `LiteLLM: ${err?.error?.message ?? 'Network Problem'}`; - return []; - }); - - if (error) { - throw error; - } - - return res; -}; diff --git a/src/lib/apis/modelfiles/index.ts b/src/lib/apis/models/index.ts similarity index 65% rename from src/lib/apis/modelfiles/index.ts rename to src/lib/apis/models/index.ts index 91af5e381..9faa358d3 100644 --- a/src/lib/apis/modelfiles/index.ts +++ b/src/lib/apis/models/index.ts @@ -1,18 +1,16 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const createNewModelfile = async (token: string, modelfile: object) => { +export const addNewModel = async (token: string, model: object) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/create`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/add`, { method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', authorization: `Bearer ${token}` }, - body: JSON.stringify({ - modelfile: modelfile - }) + body: JSON.stringify(model) }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -31,10 +29,10 @@ export const createNewModelfile = async (token: string, modelfile: object) => { return res; }; -export const getModelfiles = async (token: string = '') => { +export const getModelInfos = async (token: string = '') => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, { 
+ const res = await fetch(`${WEBUI_API_BASE_URL}/models`, { method: 'GET', headers: { Accept: 'application/json', @@ -59,62 +57,22 @@ export const getModelfiles = async (token: string = '') => { throw error; } - return res.map((modelfile) => modelfile.modelfile); + return res; }; -export const getModelfileByTagName = async (token: string, tagName: string) => { +export const getModelById = async (token: string, id: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, { - method: 'POST', + const searchParams = new URLSearchParams(); + searchParams.append('id', id); + + const res = await fetch(`${WEBUI_API_BASE_URL}/models?${searchParams.toString()}`, { + method: 'GET', headers: { Accept: 'application/json', 'Content-Type': 'application/json', authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - tag_name: tagName - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .then((json) => { - return json; - }) - .catch((err) => { - error = err; - - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res.modelfile; -}; - -export const updateModelfileByTagName = async ( - token: string, - tagName: string, - modelfile: object -) => { - let error = null; - - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/update`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - tag_name: tagName, - modelfile: modelfile - }) + } }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -137,19 +95,55 @@ export const updateModelfileByTagName = async ( return res; }; -export const deleteModelfileByTagName = async (token: string, tagName: string) => { +export const updateModelById = async (token: string, id: string, model: object) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/delete`, { + const 
searchParams = new URLSearchParams(); + searchParams.append('id', id); + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/update?${searchParams.toString()}`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify(model) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteModelById = async (token: string, id: string) => { + let error = null; + + const searchParams = new URLSearchParams(); + searchParams.append('id', id); + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/delete?${searchParams.toString()}`, { method: 'DELETE', headers: { Accept: 'application/json', 'Content-Type': 'application/json', authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - tag_name: tagName - }) + } }) .then(async (res) => { if (!res.ok) throw await res.json(); diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index 7ecd65efe..efc3f0d0f 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -1,6 +1,73 @@ import { OLLAMA_API_BASE_URL } from '$lib/constants'; import { promptTemplate } from '$lib/utils'; +export const getOllamaConfig = async (token: string = '') => { + let error = null; + + const res = await fetch(`${OLLAMA_API_BASE_URL}/config`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw 
error; + } + + return res; +}; + +export const updateOllamaConfig = async (token: string = '', enable_ollama_api: boolean) => { + let error = null; + + const res = await fetch(`${OLLAMA_API_BASE_URL}/config/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + enable_ollama_api: enable_ollama_api + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getOllamaUrls = async (token: string = '') => { let error = null; @@ -97,7 +164,7 @@ export const getOllamaVersion = async (token: string = '') => { throw error; } - return res?.version ?? ''; + return res?.version ?? false; }; export const getOllamaModels = async (token: string = '') => { diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 02281eff0..8afcec018 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -230,7 +230,12 @@ export const getOpenAIModels = async (token: string = '') => { return models ? models - .map((model) => ({ id: model.id, name: model.name ?? model.id, external: true })) + .map((model) => ({ + id: model.id, + name: model.name ?? model.id, + external: true, + custom_info: model.custom_info + })) .sort((a, b) => { return a.name.localeCompare(b.name); }) diff --git a/src/lib/apis/users/index.ts b/src/lib/apis/users/index.ts index 2d2bd386f..4c97b0036 100644 --- a/src/lib/apis/users/index.ts +++ b/src/lib/apis/users/index.ts @@ -115,6 +115,62 @@ export const getUsers = async (token: string) => { return res ? 
res : []; }; +export const getUserSettings = async (token: string) => { + let error = null; + const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/settings`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateUserSettings = async (token: string, settings: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/settings/update`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...settings + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getUserById = async (token: string, userId: string) => { let error = null; diff --git a/src/lib/components/admin/Settings/Banners.svelte b/src/lib/components/admin/Settings/Banners.svelte new file mode 100644 index 000000000..e69a8ebb1 --- /dev/null +++ b/src/lib/components/admin/Settings/Banners.svelte @@ -0,0 +1,137 @@ + + +
{ + updateBanners(); + saveHandler(); + }} +> +
+
+
+
+ {$i18n.t('Banners')} +
+ + +
+
+ {#each banners as banner, bannerIdx} +
+
+ + + + +
+ + + +
+
+ + +
+ {/each} +
+
+
+
+ +
+
diff --git a/src/lib/components/admin/Settings/Database.svelte b/src/lib/components/admin/Settings/Database.svelte index cde6bcaa4..3ae01a668 100644 --- a/src/lib/components/admin/Settings/Database.svelte +++ b/src/lib/components/admin/Settings/Database.svelte @@ -1,13 +1,24 @@ @@ -114,6 +137,47 @@ +
+
{$i18n.t('Enable Community Sharing')}
+ + +
+
diff --git a/src/lib/components/admin/Settings/Users.svelte b/src/lib/components/admin/Settings/Users.svelte index f2a8bb19a..44e38f40c 100644 --- a/src/lib/components/admin/Settings/Users.svelte +++ b/src/lib/components/admin/Settings/Users.svelte @@ -1,15 +1,19 @@ @@ -34,10 +39,13 @@ class="flex flex-col h-full justify-between space-y-3 text-sm" on:submit|preventDefault={async () => { // console.log('submit'); - await updateUserPermissions(localStorage.token, permissions); + await setDefaultModels(localStorage.token, defaultModelId); + await updateUserPermissions(localStorage.token, permissions); await updateModelFilterConfig(localStorage.token, whitelistEnabled, whitelistModels); saveHandler(); + + await config.set(await getBackendConfig()); }} >
@@ -88,26 +96,40 @@
-
+
{$i18n.t('Manage Models')}
+
+
+
+
{$i18n.t('Default Model')}
+
+
-
-
+
+ +
+
+ +
+
{$i18n.t('Model Whitelisting')}
- +
diff --git a/src/lib/components/admin/SettingsModal.svelte b/src/lib/components/admin/SettingsModal.svelte index 923ab576a..38a2602b6 100644 --- a/src/lib/components/admin/SettingsModal.svelte +++ b/src/lib/components/admin/SettingsModal.svelte @@ -6,6 +6,9 @@ import General from './Settings/General.svelte'; import Users from './Settings/Users.svelte'; + import Banners from '$lib/components/admin/Settings/Banners.svelte'; + import { toast } from 'svelte-sonner'; + const i18n = getContext('i18n'); export let show = false; @@ -117,24 +120,63 @@
{$i18n.t('Database')}
+ +
{#if selectedTab === 'general'} { show = false; + toast.success($i18n.t('Settings saved successfully!')); }} /> {:else if selectedTab === 'users'} { show = false; + toast.success($i18n.t('Settings saved successfully!')); }} /> {:else if selectedTab === 'db'} { show = false; + toast.success($i18n.t('Settings saved successfully!')); + }} + /> + {:else if selectedTab === 'banners'} + { + show = false; + toast.success($i18n.t('Settings saved successfully!')); }} /> {/if} diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte new file mode 100644 index 000000000..d2050dd05 --- /dev/null +++ b/src/lib/components/chat/Chat.svelte @@ -0,0 +1,1090 @@ + + + + + {title + ? `${title.length > 30 ? `${title.slice(0, 30)}...` : title} | ${$WEBUI_NAME}` + : `${$WEBUI_NAME}`} + + + +{#if !chatIdProp || (loaded && chatIdProp)} +
+ 0} + {chat} + {initNewChat} + /> + + {#if $banners.length > 0 && !$chatId && selectedModels.length <= 1} +
+
+ {#each $banners.filter( (b) => (b.dismissible ? !JSON.parse(localStorage.getItem('dismissedBannerIds') ?? '[]').includes(b.id) : true) ) as banner} + { + const bannerId = e.detail; + + localStorage.setItem( + 'dismissedBannerIds', + JSON.stringify( + [ + bannerId, + ...JSON.parse(localStorage.getItem('dismissedBannerIds') ?? '[]') + ].filter((id) => $banners.find((b) => b.id === id)) + ) + ); + }} + /> + {/each} +
+
+ {/if} + +
+ +
+
+ + +{/if} diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index acf797cd1..afff1217d 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -1,7 +1,7 @@ -{#if code} -
-
-
{@html lang}
+
+
+
{@html lang}
-
- {#if lang.toLowerCase() === 'python' || lang.toLowerCase() === 'py' || (lang === '' && checkPythonCode(code))} - {#if executing} -
Running
- {:else} - - {/if} +
+ {#if lang.toLowerCase() === 'python' || lang.toLowerCase() === 'py' || (lang === '' && checkPythonCode(code))} + {#if executing} +
Running
+ {:else} + {/if} - -
+ {/if} +
- -
{@html highlightedCode || code}
- -
- - {#if executing} -
-
STDOUT/STDERR
-
Running...
-
- {:else if stdout || stderr || result} -
-
STDOUT/STDERR
-
{stdout || stderr || result}
-
- {/if}
-{/if} + +
{@html highlightedCode || code}
+ +
+ + {#if executing} +
+
STDOUT/STDERR
+
Running...
+
+ {:else if stdout || stderr || result} +
+
STDOUT/STDERR
+
{stdout || stderr || result}
+
+ {/if} +
diff --git a/src/lib/components/chat/Messages/CompareMessages.svelte b/src/lib/components/chat/Messages/CompareMessages.svelte index 60efdb2ab..f904a57ab 100644 --- a/src/lib/components/chat/Messages/CompareMessages.svelte +++ b/src/lib/components/chat/Messages/CompareMessages.svelte @@ -13,8 +13,6 @@ export let parentMessage; - export let selectedModelfiles; - export let updateChatMessages: Function; export let confirmEditResponseMessage: Function; export let rateMessage: Function; @@ -130,7 +128,6 @@ > m.id)} isLastMessage={true} {updateChatMessages} diff --git a/src/lib/components/chat/Messages/Placeholder.svelte b/src/lib/components/chat/Messages/Placeholder.svelte index dfb6cfb36..ed121dbe6 100644 --- a/src/lib/components/chat/Messages/Placeholder.svelte +++ b/src/lib/components/chat/Messages/Placeholder.svelte @@ -1,6 +1,6 @@ - -
-
-
{$i18n.t('Parameters')}
- - -
- -
-
-
{$i18n.t('Keep Alive')}
- - -
- - {#if keepAlive !== null} -
- -
- {/if} -
- -
-
-
{$i18n.t('Request Mode')}
- - -
-
-
- -
- -
-
diff --git a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte index 6eaf82da8..93c482711 100644 --- a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte +++ b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte @@ -1,14 +1,16 @@ -
-
-
-
{$i18n.t('Seed')}
-
- -
+
+
+
+
{$i18n.t('Seed')}
+ +
+ + {#if (params?.seed ?? null) !== null} +
+
+ +
+
+ {/if}
-
-
-
{$i18n.t('Stop Sequence')}
-
- -
+
+
+
{$i18n.t('Stop Sequence')}
+ +
+ + {#if (params?.stop ?? null) !== null} +
+
+ +
+
+ {/if}
@@ -61,10 +109,10 @@ class="p-1 px-3 text-xs flex rounded transition" type="button" on:click={() => { - options.temperature = options.temperature === '' ? 0.8 : ''; + params.temperature = (params?.temperature ?? '') === '' ? 0.8 : ''; }} > - {#if options.temperature === ''} + {#if (params?.temperature ?? '') === ''} {$i18n.t('Default')} {:else} {$i18n.t('Custom')} @@ -72,7 +120,7 @@
- {#if options.temperature !== ''} + {#if (params?.temperature ?? '') !== ''}
{ - options.mirostat = options.mirostat === '' ? 0 : ''; + params.mirostat = (params?.mirostat ?? '') === '' ? 0 : ''; }} > - {#if options.mirostat === ''} + {#if (params?.mirostat ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.mirostat !== ''} + {#if (params?.mirostat ?? '') !== ''}
{ - options.mirostat_eta = options.mirostat_eta === '' ? 0.1 : ''; + params.mirostat_eta = (params?.mirostat_eta ?? '') === '' ? 0.1 : ''; }} > - {#if options.mirostat_eta === ''} + {#if (params?.mirostat_eta ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.mirostat_eta !== ''} + {#if (params?.mirostat_eta ?? '') !== ''}
{ - options.mirostat_tau = options.mirostat_tau === '' ? 5.0 : ''; + params.mirostat_tau = (params?.mirostat_tau ?? '') === '' ? 5.0 : ''; }} > - {#if options.mirostat_tau === ''} + {#if (params?.mirostat_tau ?? '') === ''} {$i18n.t('Default')} {:else} {$i18n.t('Custom')} @@ -210,7 +258,7 @@
- {#if options.mirostat_tau !== ''} + {#if (params?.mirostat_tau ?? '') !== ''}
{ - options.top_k = options.top_k === '' ? 40 : ''; + params.top_k = (params?.top_k ?? '') === '' ? 40 : ''; }} > - {#if options.top_k === ''} + {#if (params?.top_k ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.top_k !== ''} + {#if (params?.top_k ?? '') !== ''}
{ - options.top_p = options.top_p === '' ? 0.9 : ''; + params.top_p = (params?.top_p ?? '') === '' ? 0.9 : ''; }} > - {#if options.top_p === ''} + {#if (params?.top_p ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.top_p !== ''} + {#if (params?.top_p ?? '') !== ''}
-
{$i18n.t('Repeat Penalty')}
+
{$i18n.t('Frequencey Penalty')}
- {#if options.repeat_penalty !== ''} + {#if (params?.frequency_penalty ?? '') !== ''}
{ - options.repeat_last_n = options.repeat_last_n === '' ? 64 : ''; + params.repeat_last_n = (params?.repeat_last_n ?? '') === '' ? 64 : ''; }} > - {#if options.repeat_last_n === ''} + {#if (params?.repeat_last_n ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.repeat_last_n !== ''} + {#if (params?.repeat_last_n ?? '') !== ''}
{ - options.tfs_z = options.tfs_z === '' ? 1 : ''; + params.tfs_z = (params?.tfs_z ?? '') === '' ? 1 : ''; }} > - {#if options.tfs_z === ''} + {#if (params?.tfs_z ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.tfs_z !== ''} + {#if (params?.tfs_z ?? '') !== ''}
{ - options.num_ctx = options.num_ctx === '' ? 2048 : ''; + params.num_ctx = (params?.num_ctx ?? '') === '' ? 2048 : ''; }} > - {#if options.num_ctx === ''} + {#if (params?.num_ctx ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.num_ctx !== ''} + {#if (params?.num_ctx ?? '') !== ''}
-
{$i18n.t('Max Tokens')}
+
{$i18n.t('Max Tokens (num_predict)')}
- {#if options.num_predict !== ''} + {#if (params?.max_tokens ?? '') !== ''}
{/if}
+
+
+
{$i18n.t('Template')}
+ + +
+ + {#if (params?.template ?? null) !== null} +
+
+