Merge pull request #3559 from open-webui/dev

0.3.8
This commit is contained in:
Timothy Jaeryang Baek 2024-07-09 14:25:16 -07:00 committed by GitHub
commit 9bcd4ce5c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
178 changed files with 6566 additions and 3314 deletions

View File

@ -10,7 +10,8 @@ node_modules
vite.config.js.timestamp-*
vite.config.ts.timestamp-*
__pycache__
.env
.idea
venv
_old
uploads
.ipynb_checkpoints

View File

@ -35,6 +35,10 @@ jobs:
done
echo "Service is up!"
- name: Delete Docker build cache
run: |
docker builder prune --all --force
- name: Preload Ollama model
run: |
docker exec ollama ollama pull qwen:0.5b-chat-v1.5-q2_K
@ -43,7 +47,7 @@ jobs:
uses: cypress-io/github-action@v6
with:
browser: chrome
wait-on: "http://localhost:3000"
wait-on: 'http://localhost:3000'
config: baseUrl=http://localhost:3000
- uses: actions/upload-artifact@v4
@ -67,6 +71,28 @@ jobs:
path: compose-logs.txt
if-no-files-found: ignore
# pytest:
# name: Run Backend Tests
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python
# uses: actions/setup-python@v4
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
# run: |
# python -m pip install --upgrade pip
# pip install -r backend/requirements.txt
# - name: pytest run
# run: |
# ls -al
# cd backend
# PYTHONPATH=. pytest . -o log_cli=true -o log_cli_level=INFO
migration_test:
name: Run Migration Tests
runs-on: ubuntu-latest
@ -126,11 +152,11 @@ jobs:
cd backend
uvicorn main:app --port "8080" --forwarded-allow-ips '*' &
UVICORN_PID=$!
# Wait up to 20 seconds for the server to start
for i in {1..20}; do
# Wait up to 40 seconds for the server to start
for i in {1..40}; do
curl -s http://localhost:8080/api/config > /dev/null && break
sleep 1
if [ $i -eq 20 ]; then
if [ $i -eq 40 ]; then
echo "Server failed to start"
kill -9 $UVICORN_PID
exit 1
@ -171,7 +197,7 @@ jobs:
fi
# Check that service will reconnect to postgres when connection will be closed
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has failed before postgres reconnect check"
exit 1
@ -183,7 +209,7 @@ jobs:
cur = conn.cursor(); \
cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
exit 1

1
.gitignore vendored
View File

@ -306,3 +306,4 @@ dist
# cypress artifacts
cypress/videos
cypress/screenshots
.vscode/settings.json

View File

@ -5,6 +5,33 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.8] - 2024-07-09
### Added
- **💬 Chat Controls**: Easily adjust parameters for each chat session, offering more precise control over your interactions.
- **📌 Pinned Chats**: Support for pinned chats, allowing you to keep important conversations easily accessible.
- **📄 Apache Tika Integration**: Added support for using Apache Tika as a document loader, enhancing document processing capabilities.
- **🛠️ Custom Environment for OpenID Claims**: Allows setting custom claims for OpenID, providing more flexibility in user authentication.
- **🔧 Enhanced Tools & Functions API**: Introduced 'event_emitter' and 'event_call', now you can also add citations for better documentation and tracking. Detailed documentation will be provided on our documentation website.
- **↔️ Sideways Scrolling in Settings**: Settings tabs container now supports horizontal scrolling for easier navigation.
- **🌑 Darker OLED Theme**: Includes a new, darker OLED theme and improved styling for the light theme, enhancing visual appeal.
- **🌐 Language Updates**: Updated translations for Indonesian, German, French, and Catalan languages, expanding accessibility.
### Fixed
- **⏰ OpenAI Streaming Timeout**: Resolved issues with OpenAI streaming response using the 'AIOHTTP_CLIENT_TIMEOUT' setting, ensuring reliable performance.
- **💡 User Valves**: Fixed malfunctioning user valves, ensuring proper functionality.
- **🔄 Collapsible Components**: Addressed issues with collapsible components not working, restoring expected behavior.
### Changed
- **🗃️ Database Backend**: Switched from Peewee to SQLAlchemy for improved concurrency support, enhancing database performance.
- **🔤 Primary Font Styling**: Updated primary font to Archivo for better visual consistency.
- **🔄 Font Change for Windows**: Replaced Arimo with Inter font for Windows users, improving readability.
- **🚀 Lazy Loading**: Implemented lazy loading for 'faster_whisper' and 'sentence_transformers' to reduce startup memory usage.
- **📋 Task Generation Payload**: Task generations now include only the "task" field in the body instead of "title".
## [0.3.7] - 2024-06-29
### Added

114
backend/alembic.ini Normal file
View File

@ -0,0 +1,114 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# sqlalchemy.url = REPLACE_WITH_DATABASE_URL
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@ -14,7 +14,6 @@ from fastapi import (
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel
import uuid
@ -277,6 +276,8 @@ def transcribe(
f.close()
if app.state.config.STT_ENGINE == "":
from faster_whisper import WhisperModel
whisper_kwargs = {
"model_size_or_path": WHISPER_MODEL,
"device": whisper_device_type,

View File

@ -12,7 +12,6 @@ from fastapi import (
Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES
from utils.utils import (

View File

@ -25,6 +25,7 @@ from utils.task import prompt_template
from config import (
SRC_LOG_LEVELS,
ENABLE_OPENAI_API,
AIOHTTP_CLIENT_TIMEOUT,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
CACHE_DIR,
@ -463,7 +464,9 @@ async def generate_chat_completion(
streaming = False
try:
session = aiohttp.ClientSession(trust_env=True)
session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
r = await session.request(
method="POST",
url=f"{url}/chat/completions",

View File

@ -48,8 +48,6 @@ import mimetypes
import uuid
import json
import sentence_transformers
from apps.webui.models.documents import (
Documents,
DocumentForm,
@ -93,6 +91,8 @@ from config import (
SRC_LOG_LEVELS,
UPLOAD_DIR,
DOCS_DIR,
CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL,
RAG_TOP_K,
RAG_RELEVANCE_THRESHOLD,
RAG_EMBEDDING_ENGINE,
@ -148,6 +148,9 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
@ -190,6 +193,8 @@ def update_embedding_model(
update_model: bool = False,
):
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
import sentence_transformers
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model),
device=DEVICE_TYPE,
@ -204,6 +209,8 @@ def update_reranking_model(
update_model: bool = False,
):
if reranking_model:
import sentence_transformers
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
get_model_path(reranking_model, update_model),
device=DEVICE_TYPE,
@ -388,6 +395,10 @@ async def get_rag_config(user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"content_extraction": {
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": app.state.config.TIKA_SERVER_URL,
},
"chunk": {
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
@ -417,6 +428,11 @@ async def get_rag_config(user=Depends(get_admin_user)):
}
class ContentExtractionConfig(BaseModel):
engine: str = ""
tika_server_url: Optional[str] = None
class ChunkParamUpdateForm(BaseModel):
chunk_size: int
chunk_overlap: int
@ -450,6 +466,7 @@ class WebConfig(BaseModel):
class ConfigUpdateForm(BaseModel):
pdf_extract_images: Optional[bool] = None
content_extraction: Optional[ContentExtractionConfig] = None
chunk: Optional[ChunkParamUpdateForm] = None
youtube: Optional[YoutubeLoaderConfig] = None
web: Optional[WebConfig] = None
@ -463,6 +480,11 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
else app.state.config.PDF_EXTRACT_IMAGES
)
if form_data.content_extraction is not None:
log.info(f"Updating text settings: {form_data.content_extraction}")
app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine
app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
if form_data.chunk is not None:
app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
@ -499,6 +521,10 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
return {
"status": True,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"content_extraction": {
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": app.state.config.TIKA_SERVER_URL,
},
"chunk": {
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
@ -978,13 +1004,49 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
return True
except Exception as e:
log.exception(e)
if e.__class__.__name__ == "UniqueConstraintError":
return True
log.exception(e)
return False
class TikaLoader:
def __init__(self, file_path, mime_type=None):
self.file_path = file_path
self.mime_type = mime_type
def load(self) -> List[Document]:
with open(self.file_path, "rb") as f:
data = f.read()
if self.mime_type is not None:
headers = {"Content-Type": self.mime_type}
else:
headers = {}
endpoint = app.state.config.TIKA_SERVER_URL
if not endpoint.endswith("/"):
endpoint += "/"
endpoint += "tika/text"
r = requests.put(endpoint, data=data, headers=headers)
if r.ok:
raw_metadata = r.json()
text = raw_metadata.get("X-TIKA:content", "<No text content found>")
if "Content-Type" in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"]
log.info("Tika extracted text: %s", text)
return [Document(page_content=text, metadata=headers)]
else:
raise Exception(f"Error calling Tika: {r.reason}")
def get_loader(filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
known_type = True
@ -1035,6 +1097,17 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"msg",
]
if (
app.state.config.CONTENT_EXTRACTION_ENGINE == "tika"
and app.state.config.TIKA_SERVER_URL
):
if file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = TikaLoader(file_path, file_content_type)
else:
if file_ext == "pdf":
loader = PyPDFLoader(
file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES

View File

@ -294,15 +294,17 @@ def get_rag_context(
extracted_collections.extend(collection_names)
context_string = ""
contexts = []
citations = []
for context in relevant_contexts:
try:
if "documents" in context:
context_string += "\n\n".join(
contexts.append(
"\n\n".join(
[text for text in context["documents"][0] if text is not None]
)
)
if "metadatas" in context:
citations.append(
@ -315,9 +317,7 @@ def get_rag_context(
except Exception as e:
log.exception(e)
context_string = context_string.strip()
return context_string, citations
return contexts, citations
def get_model_path(model: str, update_model: bool = False):
@ -442,8 +442,6 @@ from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import Extra
from sentence_transformers import util
class RerankCompressor(BaseDocumentCompressor):
embedding_function: Any
@ -468,6 +466,8 @@ class RerankCompressor(BaseDocumentCompressor):
[(query, doc.page_content) for doc in documents]
)
else:
from sentence_transformers import util
query_embedding = self.embedding_function(query)
document_embedding = self.embedding_function(
[doc.page_content for doc in documents]

View File

@ -1,18 +1,39 @@
import os
import logging
import json
from contextlib import contextmanager
from peewee import *
from peewee_migrate import Router
from apps.webui.internal.wrappers import register_connection
from typing import Optional, Any
from typing_extensions import Self
from sqlalchemy import create_engine, types, Dialect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.sql.type_api import _T
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
class JSONField(TextField):
class JSONField(types.TypeDecorator):
impl = types.Text
cache_ok = True
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any:
return json.dumps(value)
def process_result_value(self, value: Optional[_T], dialect: Dialect) -> Any:
if value is not None:
return json.loads(value)
def copy(self, **kw: Any) -> Self:
return JSONField(self.impl.length)
def db_value(self, value):
return json.dumps(value)
@ -30,25 +51,60 @@ else:
pass
# The `register_connection` function encapsulates the logic for setting up
# the database connection based on the connection string, while `connect`
# is a Peewee-specific method to manage the connection state and avoid errors
# when a connection is already open.
# Workaround to handle the peewee migration
# This is required to ensure the peewee migration is handled before the alembic migration
def handle_peewee_migration(DATABASE_URL):
try:
DB = register_connection(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.")
# Replace the postgresql:// with postgres:// and %40 with @ in the DATABASE_URL
db = register_connection(
DATABASE_URL.replace("postgresql://", "postgres://").replace("%40", "@")
)
migrate_dir = BACKEND_DIR / "apps" / "webui" / "internal" / "migrations"
router = Router(db, logger=log, migrate_dir=migrate_dir)
router.run()
db.close()
# check if db connection has been closed
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}")
raise
router = Router(
DB,
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
logger=log,
finally:
# Properly closing the database connection
if db and not db.is_closed():
db.close()
# Assert if db connection has been closed
assert db.is_closed(), "Database connection is still open."
handle_peewee_migration(DATABASE_URL)
SQLALCHEMY_DATABASE_URL = DATABASE_URL
if "sqlite" in SQLALCHEMY_DATABASE_URL:
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
router.run()
else:
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
Base = declarative_base()
Session = scoped_session(SessionLocal)
# Dependency
def get_session():
db = SessionLocal()
try:
DB.connect(reuse_if_open=True)
except OperationalError as e:
log.info(f"Failed to connect to database again due to: {e}")
pass
yield db
finally:
db.close()
get_db = contextmanager(get_session)

View File

@ -1,10 +1,7 @@
"""Peewee migrations -- 017_add_user_oauth_sub.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
@ -21,7 +18,6 @@ Some examples (model - class or model name)::
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress

View File

@ -1,21 +0,0 @@
# Database Migrations
This directory contains all the database migrations for the web app.
Migrations are done using the [`peewee-migrate`](https://github.com/klen/peewee_migrate) library.
Migrations are automatically ran at app startup.
## Creating a migration
Have you made a change to the schema of an existing model?
You will need to create a migration file to ensure that existing databases are updated for backwards compatibility.
1. Have a database file (`webui.db`) that has the old schema prior to any of your changes.
2. Make your changes to the models.
3. From the `backend` directory, run the following command:
```bash
pw_migrate create --auto --auto-source apps.webui.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME}
```
- `$SQLITE_DB` should be the path to the database file.
- `$MIGRATION_NAME` should be a descriptive name for the migration.
4. The migration file will be created in the `apps/web/internal/migrations` directory.

View File

@ -3,7 +3,7 @@ from fastapi.routing import APIRoute
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from sqlalchemy.orm import Session
from apps.webui.routers import (
auths,
users,
@ -19,8 +19,13 @@ from apps.webui.routers import (
functions,
)
from apps.webui.models.functions import Functions
from apps.webui.models.models import Models
from apps.webui.utils import load_function_module_by_id
from utils.misc import stream_message_template
from utils.task import prompt_template
from config import (
WEBUI_BUILD_HASH,
@ -39,6 +44,8 @@ from config import (
WEBUI_BANNERS,
ENABLE_COMMUNITY_SHARING,
AppConfig,
OAUTH_USERNAME_CLAIM,
OAUTH_PICTURE_CLAIM,
)
import inspect
@ -74,6 +81,9 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.MODELS = {}
app.state.TOOLS = {}
app.state.FUNCTIONS = {}
@ -129,7 +139,6 @@ async def get_pipe_models():
function_module = app.state.FUNCTIONS[pipe.id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
print(f"Getting valves for {pipe.id}")
valves = Functions.get_function_valves_by_id(pipe.id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
@ -181,6 +190,77 @@ async def get_pipe_models():
async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
if model_info.params:
if model_info.params.get("temperature", None) is not None:
form_data["temperature"] = float(model_info.params.get("temperature"))
if model_info.params.get("top_p", None):
form_data["top_p"] = int(model_info.params.get("top_p", None))
if model_info.params.get("max_tokens", None):
form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
if model_info.params.get("frequency_penalty", None):
form_data["frequency_penalty"] = int(
model_info.params.get("frequency_penalty", None)
)
if model_info.params.get("seed", None):
form_data["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
form_data["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if form_data.get("messages"):
for message in form_data["messages"]:
if message.get("role") == "system":
message["content"] = system + message["content"]
break
else:
form_data["messages"].insert(
0,
{
"role": "system",
"content": system,
},
)
else:
pass
async def job():
pipe_id = form_data["model"]
if "." in pipe_id:
@ -259,6 +339,9 @@ async def generate_function_chat_completion(form_data, user):
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except:

View File

@ -1,14 +1,13 @@
from pydantic import BaseModel
from typing import List, Union, Optional
import time
from typing import Optional
import uuid
import logging
from peewee import *
from sqlalchemy import String, Column, Boolean, Text
from apps.webui.models.users import UserModel, Users
from utils.utils import verify_password
from apps.webui.internal.db import DB
from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS
@ -20,14 +19,13 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Auth(Model):
id = CharField(unique=True)
email = CharField()
password = TextField()
active = BooleanField()
class Auth(Base):
__tablename__ = "auth"
class Meta:
database = DB
id = Column(String, primary_key=True)
email = Column(String)
password = Column(Text)
active = Column(Boolean)
class AuthModel(BaseModel):
@ -94,9 +92,6 @@ class AddUserForm(SignupForm):
class AuthsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Auth])
def insert_new_auth(
self,
@ -107,6 +102,8 @@ class AuthsTable:
role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
with get_db() as db:
log.info("insert_new_auth")
id = str(uuid.uuid4())
@ -114,12 +111,16 @@ class AuthsTable:
auth = AuthModel(
**{"id": id, "email": email, "password": password, "active": True}
)
result = Auth.create(**auth.model_dump())
result = Auth(**auth.model_dump())
db.add(result)
user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub
)
db.commit()
db.refresh(result)
if result and user:
return user
else:
@ -128,7 +129,9 @@ class AuthsTable:
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
try:
auth = Auth.get(Auth.email == email, Auth.active == True)
with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()
if auth:
if verify_password(password, auth.password):
user = Users.get_user_by_id(auth.id)
@ -155,7 +158,8 @@ class AuthsTable:
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}")
try:
auth = Auth.get(Auth.email == email, Auth.active == True)
with get_db() as db:
auth = db.query(Auth).filter(email=email, active=True).first()
if auth:
user = Users.get_user_by_id(auth.id)
return user
@ -164,31 +168,34 @@ class AuthsTable:
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
try:
query = Auth.update(password=new_password).where(Auth.id == id)
result = query.execute()
with get_db() as db:
result = (
db.query(Auth).filter_by(id=id).update({"password": new_password})
)
db.commit()
return True if result == 1 else False
except:
return False
def update_email_by_id(self, id: str, email: str) -> bool:
try:
query = Auth.update(email=email).where(Auth.id == id)
result = query.execute()
with get_db() as db:
result = db.query(Auth).filter_by(id=id).update({"email": email})
db.commit()
return True if result == 1 else False
except:
return False
def delete_auth_by_id(self, id: str) -> bool:
try:
with get_db() as db:
# Delete User
result = Users.delete_user_by_id(id)
if result:
# Delete Auth
query = Auth.delete().where(Auth.id == id)
query.execute() # Remove the rows, return number of rows removed.
db.query(Auth).filter_by(id=id).delete()
db.commit()
return True
else:
@ -197,4 +204,4 @@ class AuthsTable:
return False
Auths = AuthsTable(DB)
Auths = AuthsTable()

View File

@ -1,36 +1,38 @@
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional
from peewee import *
from playhouse.shortcuts import model_to_dict
import json
import uuid
import time
from apps.webui.internal.db import DB
from sqlalchemy import Column, String, BigInteger, Boolean, Text
from apps.webui.internal.db import Base, get_db
####################
# Chat DB Schema
####################
class Chat(Model):
id = CharField(unique=True)
user_id = CharField()
title = TextField()
chat = TextField() # Save Chat JSON as Text
class Chat(Base):
__tablename__ = "chat"
created_at = BigIntegerField()
updated_at = BigIntegerField()
id = Column(String, primary_key=True)
user_id = Column(String)
title = Column(Text)
chat = Column(Text) # Save Chat JSON as Text
share_id = CharField(null=True, unique=True)
archived = BooleanField(default=False)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class Meta:
database = DB
share_id = Column(Text, unique=True, nullable=True)
archived = Column(Boolean, default=False)
class ChatModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
title: str
@ -75,18 +77,19 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable:
def __init__(self, db):
self.db = db
db.create_tables([Chat])
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": (
form_data.chat["title"] if "title" in form_data.chat else "New Chat"
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": json.dumps(form_data.chat),
"created_at": int(time.time()),
@ -94,26 +97,32 @@ class ChatTable:
}
)
result = Chat.create(**chat.model_dump())
return chat if result else None
result = Chat(**chat.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try:
query = Chat.update(
chat=json.dumps(chat),
title=chat["title"] if "title" in chat else "New Chat",
updated_at=int(time.time()),
).where(Chat.id == id)
query.execute()
with get_db() as db:
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
except:
chat_obj = db.get(Chat, id)
chat_obj.chat = json.dumps(chat)
chat_obj.title = chat["title"] if "title" in chat else "New Chat"
chat_obj.updated_at = int(time.time())
db.commit()
db.refresh(chat_obj)
return ChatModel.model_validate(chat_obj)
except Exception as e:
return None
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db:
# Get the existing chat to share
chat = Chat.get(Chat.id == chat_id)
chat = db.get(Chat, chat_id)
# Check if the chat is already shared
if chat.share_id:
return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
@ -128,36 +137,42 @@ class ChatTable:
"updated_at": int(time.time()),
}
)
shared_result = Chat.create(**shared_chat.model_dump())
shared_result = Chat(**shared_chat.model_dump())
db.add(shared_result)
db.commit()
db.refresh(shared_result)
# Update the original chat with the share_id
result = (
Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute()
db.query(Chat)
.filter_by(id=chat_id)
.update({"share_id": shared_chat.id})
)
db.commit()
return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
print("update_shared_chat_by_id")
chat = Chat.get(Chat.id == chat_id)
chat = db.get(Chat, chat_id)
print(chat)
chat.title = chat.title
chat.chat = chat.chat
db.commit()
db.refresh(chat)
query = Chat.update(
title=chat.title,
chat=chat.chat,
).where(Chat.id == chat.share_id)
query.execute()
chat = Chat.get(Chat.id == chat.share_id)
return ChatModel(**model_to_dict(chat))
return self.get_chat_by_id(chat.share_id)
except:
return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try:
query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}")
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
db.commit()
return True
except:
@ -167,40 +182,33 @@ class ChatTable:
self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]:
try:
query = Chat.update(
share_id=share_id,
).where(Chat.id == id)
query.execute()
with get_db() as db:
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
chat = db.get(Chat, id)
chat.share_id = share_id
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except:
return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try:
chat = self.get_chat_by_id(id)
query = Chat.update(
archived=(not chat.archived),
).where(Chat.id == id)
with get_db() as db:
query.execute()
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
chat = db.get(Chat, id)
chat.archived = not chat.archived
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except:
return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool:
try:
chats = self.get_chats_by_user_id(user_id)
for chat in chats:
query = Chat.update(
archived=True,
).where(Chat.id == chat.id)
query.execute()
with get_db() as db:
db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
db.commit()
return True
except:
return False
@ -208,15 +216,16 @@ class ChatTable:
def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == True)
.where(Chat.user_id == user_id)
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
# .limit(limit).offset(skip)
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_user_id(
self,
@ -225,92 +234,97 @@ class ChatTable:
skip: int = 0,
limit: int = 50,
) -> List[ChatModel]:
if include_archived:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
else:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == False)
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
if not include_archived:
query = query.filter_by(archived=False)
all_chats = (
query.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_chat_ids(
self, chat_ids: List[str], skip: int = 0, limit: int = 50
) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == False)
.where(Chat.id.in_(chat_ids))
with get_db() as db:
all_chats = (
db.query(Chat)
.filter(Chat.id.in_(chat_ids))
.filter_by(archived=False)
.order_by(Chat.updated_at.desc())
]
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
with get_db() as db:
chat = db.get(Chat, id)
return ChatModel.model_validate(chat)
except:
return None
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.share_id == id)
with get_db() as db:
chat = db.query(Chat).filter_by(share_id=id).first()
if chat:
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
return self.get_chat_by_id(id)
else:
return None
except:
except Exception as e:
return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
return ChatModel(**model_to_dict(chat))
with get_db() as db:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
return ChatModel.model_validate(chat)
except:
return None
def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select().order_by(Chat.updated_at.desc())
with get_db() as db:
all_chats = (
db.query(Chat)
# .limit(limit).offset(skip)
]
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
]
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == True)
.where(Chat.user_id == user_id)
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc())
]
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def delete_chat_by_id(self, id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id))
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
db.query(Chat).filter_by(id=id).delete()
db.commit()
return True and self.delete_shared_chat_by_chat_id(id)
except:
@ -318,8 +332,10 @@ class ChatTable:
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
db.query(Chat).filter_by(id=id, user_id=user_id).delete()
db.commit()
return True and self.delete_shared_chat_by_chat_id(id)
except:
@ -328,10 +344,12 @@ class ChatTable:
def delete_chats_by_user_id(self, user_id: str) -> bool:
try:
with get_db() as db:
self.delete_shared_chats_by_user_id(user_id)
query = Chat.delete().where(Chat.user_id == user_id)
query.execute() # Remove the rows, return number of rows removed.
db.query(Chat).filter_by(user_id=user_id).delete()
db.commit()
return True
except:
@ -339,17 +357,18 @@ class ChatTable:
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try:
shared_chat_ids = [
f"shared-{chat.id}"
for chat in Chat.select().where(Chat.user_id == user_id)
]
query = Chat.delete().where(Chat.user_id << shared_chat_ids)
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
db.commit()
return True
except:
return False
Chats = ChatTable(DB)
Chats = ChatTable()

View File

@ -1,14 +1,11 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
import time
import logging
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import DB
from apps.webui.internal.db import Base, get_db
import json
@ -22,20 +19,21 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Document(Model):
collection_name = CharField(unique=True)
name = CharField(unique=True)
title = TextField()
filename = TextField()
content = TextField(null=True)
user_id = CharField()
timestamp = BigIntegerField()
class Document(Base):
__tablename__ = "document"
class Meta:
database = DB
collection_name = Column(String, primary_key=True)
name = Column(String, unique=True)
title = Column(Text)
filename = Column(Text)
content = Column(Text, nullable=True)
user_id = Column(String)
timestamp = Column(BigInteger)
class DocumentModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
collection_name: str
name: str
title: str
@ -72,13 +70,12 @@ class DocumentForm(DocumentUpdateForm):
class DocumentsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Document])
def insert_new_doc(
self, user_id: str, form_data: DocumentForm
) -> Optional[DocumentModel]:
with get_db() as db:
document = DocumentModel(
**{
**form_data.model_dump(),
@ -88,9 +85,12 @@ class DocumentsTable:
)
try:
result = Document.create(**document.model_dump())
result = Document(**document.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return document
return DocumentModel.model_validate(result)
else:
return None
except:
@ -98,31 +98,35 @@ class DocumentsTable:
def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
try:
document = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(document))
with get_db() as db:
document = db.query(Document).filter_by(name=name).first()
return DocumentModel.model_validate(document) if document else None
except:
return None
def get_docs(self) -> List[DocumentModel]:
with get_db() as db:
return [
DocumentModel(**model_to_dict(doc))
for doc in Document.select()
# .limit(limit).offset(skip)
DocumentModel.model_validate(doc) for doc in db.query(Document).all()
]
def update_doc_by_name(
self, name: str, form_data: DocumentUpdateForm
) -> Optional[DocumentModel]:
try:
query = Document.update(
title=form_data.title,
name=form_data.name,
timestamp=int(time.time()),
).where(Document.name == name)
query.execute()
with get_db() as db:
doc = Document.get(Document.name == form_data.name)
return DocumentModel(**model_to_dict(doc))
db.query(Document).filter_by(name=name).update(
{
"title": form_data.title,
"name": form_data.name,
"timestamp": int(time.time()),
}
)
db.commit()
return self.get_doc_by_name(form_data.name)
except Exception as e:
log.exception(e)
return None
@ -135,26 +139,29 @@ class DocumentsTable:
doc_content = json.loads(doc.content if doc.content else "{}")
doc_content = {**doc_content, **updated}
query = Document.update(
content=json.dumps(doc_content),
timestamp=int(time.time()),
).where(Document.name == name)
query.execute()
with get_db() as db:
doc = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(doc))
db.query(Document).filter_by(name=name).update(
{
"content": json.dumps(doc_content),
"timestamp": int(time.time()),
}
)
db.commit()
return self.get_doc_by_name(name)
except Exception as e:
log.exception(e)
return None
def delete_doc_by_name(self, name: str) -> bool:
try:
query = Document.delete().where((Document.name == name))
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
db.query(Document).filter_by(name=name).delete()
db.commit()
return True
except:
return False
Documents = DocumentsTable(DB)
Documents = DocumentsTable()

View File

@ -1,10 +1,11 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
from sqlalchemy import Column, String, BigInteger, Text
from apps.webui.internal.db import JSONField, Base, get_db
import json
@ -18,15 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class File(Model):
id = CharField(unique=True)
user_id = CharField()
filename = TextField()
meta = JSONField()
created_at = BigIntegerField()
class File(Base):
__tablename__ = "file"
class Meta:
database = DB
id = Column(String, primary_key=True)
user_id = Column(String)
filename = Column(Text)
meta = Column(JSONField)
created_at = Column(BigInteger)
class FileModel(BaseModel):
@ -36,6 +36,8 @@ class FileModel(BaseModel):
meta: dict
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@ -57,11 +59,10 @@ class FileForm(BaseModel):
class FilesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([File])
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
with get_db() as db:
file = FileModel(
**{
**form_data.model_dump(),
@ -71,9 +72,12 @@ class FilesTable:
)
try:
result = File.create(**file.model_dump())
result = File(**file.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return file
return FileModel.model_validate(result)
else:
return None
except Exception as e:
@ -81,32 +85,42 @@ class FilesTable:
return None
def get_file_by_id(self, id: str) -> Optional[FileModel]:
with get_db() as db:
try:
file = File.get(File.id == id)
return FileModel(**model_to_dict(file))
file = db.get(File, id)
return FileModel.model_validate(file)
except:
return None
def get_files(self) -> List[FileModel]:
return [FileModel(**model_to_dict(file)) for file in File.select()]
with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()]
def delete_file_by_id(self, id: str) -> bool:
with get_db() as db:
try:
query = File.delete().where((File.id == id))
query.execute() # Remove the rows, return number of rows removed.
db.query(File).filter_by(id=id).delete()
db.commit()
return True
except:
return False
def delete_all_files(self) -> bool:
with get_db() as db:
try:
query = File.delete()
query.execute() # Remove the rows, return number of rows removed.
db.query(File).delete()
db.commit()
return True
except:
return False
Files = FilesTable(DB)
Files = FilesTable()

View File

@ -1,10 +1,11 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
from sqlalchemy import Column, String, Text, BigInteger, Boolean
from apps.webui.internal.db import JSONField, Base, get_db
from apps.webui.models.users import Users
import json
@ -21,21 +22,20 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Function(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
type = TextField()
content = TextField()
meta = JSONField()
valves = JSONField()
is_active = BooleanField(default=False)
is_global = BooleanField(default=False)
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Function(Base):
__tablename__ = "function"
class Meta:
database = DB
id = Column(String, primary_key=True)
user_id = Column(String)
name = Column(Text)
type = Column(Text)
content = Column(Text)
meta = Column(JSONField)
valves = Column(JSONField)
is_active = Column(Boolean)
is_global = Column(Boolean)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class FunctionMeta(BaseModel):
@ -55,6 +55,8 @@ class FunctionModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@ -85,13 +87,11 @@ class FunctionValves(BaseModel):
class FunctionsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Function])
def insert_new_function(
self, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]:
function = FunctionModel(
**{
**form_data.model_dump(),
@ -103,9 +103,13 @@ class FunctionsTable:
)
try:
result = Function.create(**function.model_dump())
with get_db() as db:
result = Function(**function.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return function
return FunctionModel.model_validate(result)
else:
return None
except Exception as e:
@ -114,52 +118,60 @@ class FunctionsTable:
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try:
function = Function.get(Function.id == id)
return FunctionModel(**model_to_dict(function))
with get_db() as db:
function = db.get(Function, id)
return FunctionModel.model_validate(function)
except:
return None
def get_functions(self, active_only=False) -> List[FunctionModel]:
with get_db() as db:
if active_only:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(Function.is_active == True)
FunctionModel.model_validate(function)
for function in db.query(Function).filter_by(is_active=True).all()
]
else:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select()
FunctionModel.model_validate(function)
for function in db.query(Function).all()
]
def get_functions_by_type(
self, type: str, active_only=False
) -> List[FunctionModel]:
with get_db() as db:
if active_only:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(
Function.type == type, Function.is_active == True
)
FunctionModel.model_validate(function)
for function in db.query(Function)
.filter_by(type=type, is_active=True)
.all()
]
else:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(Function.type == type)
FunctionModel.model_validate(function)
for function in db.query(Function).filter_by(type=type).all()
]
def get_global_filter_functions(self) -> List[FunctionModel]:
with get_db() as db:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(
Function.type == "filter",
Function.is_active == True,
Function.is_global == True,
)
FunctionModel.model_validate(function)
for function in db.query(Function)
.filter_by(type="filter", is_active=True, is_global=True)
.all()
]
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
with get_db() as db:
try:
function = Function.get(Function.id == id)
function = db.get(Function, id)
return function.valves if function.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
@ -168,24 +180,25 @@ class FunctionsTable:
def update_function_valves_by_id(
self, id: str, valves: dict
) -> Optional[FunctionValves]:
try:
query = Function.update(
**{"valves": valves},
updated_at=int(time.time()),
).where(Function.id == id)
query.execute()
with get_db() as db:
function = Function.get(Function.id == id)
return FunctionValves(**model_to_dict(function))
try:
function = db.get(Function, id)
function.valves = valves
function.updated_at = int(time.time())
db.commit()
db.refresh(function)
return self.get_function_by_id(id)
except:
return None
def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings
if "functions" not in user_settings:
@ -201,9 +214,10 @@ class FunctionsTable:
def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings
if "functions" not in user_settings:
@ -214,8 +228,7 @@ class FunctionsTable:
user_settings["functions"]["valves"][id] = valves
# Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
Users.update_user_by_id(user_id, {"settings": user_settings})
return user_settings["functions"]["valves"][id]
except Exception as e:
@ -223,39 +236,44 @@ class FunctionsTable:
return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
try:
query = Function.update(
**updated,
updated_at=int(time.time()),
).where(Function.id == id)
query.execute()
with get_db() as db:
function = Function.get(Function.id == id)
return FunctionModel(**model_to_dict(function))
try:
db.query(Function).filter_by(id=id).update(
{
**updated,
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_function_by_id(id)
except:
return None
def deactivate_all_functions(self) -> Optional[bool]:
with get_db() as db:
try:
query = Function.update(
**{"is_active": False},
updated_at=int(time.time()),
db.query(Function).update(
{
"is_active": False,
"updated_at": int(time.time()),
}
)
query.execute()
db.commit()
return True
except:
return None
def delete_function_by_id(self, id: str) -> bool:
with get_db() as db:
try:
query = Function.delete().where((Function.id == id))
query.execute() # Remove the rows, return number of rows removed.
db.query(Function).filter_by(id=id).delete()
db.commit()
return True
except:
return False
Functions = FunctionsTable(DB)
Functions = FunctionsTable()

View File

@ -1,10 +1,9 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional
from apps.webui.internal.db import DB
from apps.webui.models.chats import Chats
from sqlalchemy import Column, String, BigInteger, Text
from apps.webui.internal.db import Base, get_db
import time
import uuid
@ -14,15 +13,14 @@ import uuid
####################
class Memory(Model):
id = CharField(unique=True)
user_id = CharField()
content = TextField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Memory(Base):
__tablename__ = "memory"
class Meta:
database = DB
id = Column(String, primary_key=True)
user_id = Column(String)
content = Column(Text)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class MemoryModel(BaseModel):
@ -32,6 +30,8 @@ class MemoryModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@ -39,15 +39,14 @@ class MemoryModel(BaseModel):
class MemoriesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Memory])
def insert_new_memory(
self,
user_id: str,
content: str,
) -> Optional[MemoryModel]:
with get_db() as db:
id = str(uuid.uuid4())
memory = MemoryModel(
@ -59,9 +58,12 @@ class MemoriesTable:
"updated_at": int(time.time()),
}
)
result = Memory.create(**memory.model_dump())
result = Memory(**memory.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return memory
return MemoryModel.model_validate(result)
else:
return None
@ -70,40 +72,50 @@ class MemoriesTable:
id: str,
content: str,
) -> Optional[MemoryModel]:
with get_db() as db:
try:
memory = Memory.get(Memory.id == id)
memory.content = content
memory.updated_at = int(time.time())
memory.save()
return MemoryModel(**model_to_dict(memory))
db.query(Memory).filter_by(id=id).update(
{"content": content, "updated_at": int(time.time())}
)
db.commit()
return self.get_memory_by_id(id)
except:
return None
def get_memories(self) -> List[MemoryModel]:
with get_db() as db:
try:
memories = Memory.select()
return [MemoryModel(**model_to_dict(memory)) for memory in memories]
memories = db.query(Memory).all()
return [MemoryModel.model_validate(memory) for memory in memories]
except:
return None
def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
with get_db() as db:
try:
memories = Memory.select().where(Memory.user_id == user_id)
return [MemoryModel(**model_to_dict(memory)) for memory in memories]
memories = db.query(Memory).filter_by(user_id=user_id).all()
return [MemoryModel.model_validate(memory) for memory in memories]
except:
return None
def get_memory_by_id(self, id) -> Optional[MemoryModel]:
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
with get_db() as db:
try:
memory = Memory.get(Memory.id == id)
return MemoryModel(**model_to_dict(memory))
memory = db.get(Memory, id)
return MemoryModel.model_validate(memory)
except:
return None
def delete_memory_by_id(self, id: str) -> bool:
with get_db() as db:
try:
query = Memory.delete().where(Memory.id == id)
query.execute() # Remove the rows, return number of rows removed.
db.query(Memory).filter_by(id=id).delete()
db.commit()
return True
@ -111,22 +123,26 @@ class MemoriesTable:
return False
def delete_memories_by_user_id(self, user_id: str) -> bool:
with get_db() as db:
try:
query = Memory.delete().where(Memory.user_id == user_id)
query.execute()
db.query(Memory).filter_by(user_id=user_id).delete()
db.commit()
return True
except:
return False
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
with get_db() as db:
try:
query = Memory.delete().where(Memory.id == id, Memory.user_id == user_id)
query.execute()
db.query(Memory).filter_by(id=id, user_id=user_id).delete()
db.commit()
return True
except:
return False
Memories = MemoriesTable(DB)
Memories = MemoriesTable()

View File

@ -2,13 +2,10 @@ import json
import logging
from typing import Optional
import peewee as pw
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import DB, JSONField
from apps.webui.internal.db import Base, JSONField, get_db
from typing import List, Union, Optional
from config import SRC_LOG_LEVELS
@ -32,7 +29,7 @@ class ModelParams(BaseModel):
# ModelMeta is a model for the data stored in the meta field of the Model table
class ModelMeta(BaseModel):
profile_image_url: Optional[str] = "/favicon.png"
profile_image_url: Optional[str] = "/static/favicon.png"
description: Optional[str] = None
"""
@ -46,38 +43,37 @@ class ModelMeta(BaseModel):
pass
class Model(pw.Model):
id = pw.TextField(unique=True)
class Model(Base):
__tablename__ = "model"
id = Column(Text, primary_key=True)
"""
The model's id as used in the API. If set to an existing model, it will override the model.
"""
user_id = pw.TextField()
user_id = Column(Text)
base_model_id = pw.TextField(null=True)
base_model_id = Column(Text, nullable=True)
"""
An optional pointer to the actual model that should be used when proxying requests.
"""
name = pw.TextField()
name = Column(Text)
"""
The human-readable display name of the model.
"""
params = JSONField()
params = Column(JSONField)
"""
Holds a JSON encoded blob of parameters, see `ModelParams`.
"""
meta = JSONField()
meta = Column(JSONField)
"""
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class ModelModel(BaseModel):
@ -92,6 +88,8 @@ class ModelModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@ -115,12 +113,6 @@ class ModelForm(BaseModel):
class ModelsTable:
def __init__(
self,
db: pw.SqliteDatabase | pw.PostgresqlDatabase,
):
self.db = db
self.db.create_tables([Model])
def insert_new_model(
self, form_data: ModelForm, user_id: str
@ -134,10 +126,16 @@ class ModelsTable:
}
)
try:
result = Model.create(**model.model_dump())
with get_db() as db:
result = Model(**model.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return model
return ModelModel.model_validate(result)
else:
return None
except Exception as e:
@ -145,23 +143,33 @@ class ModelsTable:
return None
def get_all_models(self) -> List[ModelModel]:
return [ModelModel(**model_to_dict(model)) for model in Model.select()]
with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
model = Model.get(Model.id == id)
return ModelModel(**model_to_dict(model))
with get_db() as db:
model = db.get(Model, id)
return ModelModel.model_validate(model)
except:
return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try:
with get_db() as db:
# update only the fields that are present in the model
query = Model.update(**model.model_dump()).where(Model.id == id)
query.execute()
result = (
db.query(Model)
.filter_by(id=id)
.update(model.model_dump(exclude={"id"}, exclude_none=True))
)
db.commit()
model = Model.get(Model.id == id)
return ModelModel(**model_to_dict(model))
model = db.get(Model, id)
db.refresh(model)
return ModelModel.model_validate(model)
except Exception as e:
print(e)
@ -169,11 +177,14 @@ class ModelsTable:
def delete_model_by_id(self, id: str) -> bool:
try:
query = Model.delete().where(Model.id == id)
query.execute()
with get_db() as db:
db.query(Model).filter_by(id=id).delete()
db.commit()
return True
except:
return False
Models = ModelsTable(DB)
Models = ModelsTable()

View File

@ -1,13 +1,10 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import DB
from apps.webui.internal.db import Base, get_db
import json
@ -16,15 +13,14 @@ import json
####################
class Prompt(Model):
command = CharField(unique=True)
user_id = CharField()
title = TextField()
content = TextField()
timestamp = BigIntegerField()
class Prompt(Base):
__tablename__ = "prompt"
class Meta:
database = DB
command = Column(String, primary_key=True)
user_id = Column(String)
title = Column(Text)
content = Column(Text)
timestamp = Column(BigInteger)
class PromptModel(BaseModel):
@ -34,6 +30,8 @@ class PromptModel(BaseModel):
content: str
timestamp: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@ -48,10 +46,6 @@ class PromptForm(BaseModel):
class PromptsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Prompt])
def insert_new_prompt(
self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]:
@ -66,53 +60,60 @@ class PromptsTable:
)
try:
result = Prompt.create(**prompt.model_dump())
with get_db() as db:
result = Prompt(**prompt.dict())
db.add(result)
db.commit()
db.refresh(result)
if result:
return prompt
return PromptModel.model_validate(result)
else:
return None
except:
except Exception as e:
return None
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
try:
prompt = Prompt.get(Prompt.command == command)
return PromptModel(**model_to_dict(prompt))
with get_db() as db:
prompt = db.query(Prompt).filter_by(command=command).first()
return PromptModel.model_validate(prompt)
except:
return None
def get_prompts(self) -> List[PromptModel]:
with get_db() as db:
return [
PromptModel(**model_to_dict(prompt))
for prompt in Prompt.select()
# .limit(limit).offset(skip)
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
]
def update_prompt_by_command(
self, command: str, form_data: PromptForm
) -> Optional[PromptModel]:
try:
query = Prompt.update(
title=form_data.title,
content=form_data.content,
timestamp=int(time.time()),
).where(Prompt.command == command)
with get_db() as db:
query.execute()
prompt = Prompt.get(Prompt.command == command)
return PromptModel(**model_to_dict(prompt))
prompt = db.query(Prompt).filter_by(command=command).first()
prompt.title = form_data.title
prompt.content = form_data.content
prompt.timestamp = int(time.time())
db.commit()
return PromptModel.model_validate(prompt)
except:
return None
def delete_prompt_by_command(self, command: str) -> bool:
try:
query = Prompt.delete().where((Prompt.command == command))
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
db.query(Prompt).filter_by(command=command).delete()
db.commit()
return True
except:
return False
Prompts = PromptsTable(DB)
Prompts = PromptsTable()

View File

@ -1,14 +1,14 @@
from pydantic import BaseModel
from typing import List, Union, Optional
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
import json
import uuid
import time
import logging
from apps.webui.internal.db import DB
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS
@ -20,25 +20,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Tag(Model):
id = CharField(unique=True)
name = CharField()
user_id = CharField()
data = TextField(null=True)
class Tag(Base):
__tablename__ = "tag"
class Meta:
database = DB
id = Column(String, primary_key=True)
name = Column(String)
user_id = Column(String)
data = Column(Text, nullable=True)
class ChatIdTag(Model):
id = CharField(unique=True)
tag_name = CharField()
chat_id = CharField()
user_id = CharField()
timestamp = BigIntegerField()
class ChatIdTag(Base):
__tablename__ = "chatidtag"
class Meta:
database = DB
id = Column(String, primary_key=True)
tag_name = Column(String)
chat_id = Column(String)
user_id = Column(String)
timestamp = Column(BigInteger)
class TagModel(BaseModel):
@ -47,6 +45,8 @@ class TagModel(BaseModel):
user_id: str
data: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
class ChatIdTagModel(BaseModel):
id: str
@ -55,6 +55,8 @@ class ChatIdTagModel(BaseModel):
user_id: str
timestamp: int
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@ -75,17 +77,19 @@ class ChatTagsResponse(BaseModel):
class TagTable:
def __init__(self, db):
self.db = db
db.create_tables([Tag, ChatIdTag])
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
with get_db() as db:
id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try:
result = Tag.create(**tag.model_dump())
result = Tag(**tag.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return tag
return TagModel.model_validate(result)
else:
return None
except Exception as e:
@ -95,8 +99,9 @@ class TagTable:
self, name: str, user_id: str
) -> Optional[TagModel]:
try:
tag = Tag.get(Tag.name == name, Tag.user_id == user_id)
return TagModel(**model_to_dict(tag))
with get_db() as db:
tag = db.query(Tag).filter(name=name, user_id=user_id).first()
return TagModel.model_validate(tag)
except Exception as e:
return None
@ -118,81 +123,109 @@ class TagTable:
}
)
try:
result = ChatIdTag.create(**chatIdTag.model_dump())
with get_db() as db:
result = ChatIdTag(**chatIdTag.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return chatIdTag
return ChatIdTagModel.model_validate(result)
else:
return None
except:
return None
def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
with get_db() as db:
tag_names = [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select()
.where(ChatIdTag.user_id == user_id)
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [
TagModel(**model_to_dict(tag))
for tag in Tag.select()
.where(Tag.user_id == user_id)
.where(Tag.name.in_(tag_names))
TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
]
def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str
) -> List[TagModel]:
with get_db() as db:
tag_names = [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select()
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id))
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, chat_id=chat_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [
TagModel(**model_to_dict(tag))
for tag in Tag.select()
.where(Tag.user_id == user_id)
.where(Tag.name.in_(tag_names))
TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
]
def get_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> Optional[ChatIdTagModel]:
) -> List[ChatIdTagModel]:
with get_db() as db:
return [
ChatIdTagModel(**model_to_dict(chat_id_tag))
for chat_id_tag in ChatIdTag.select()
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.tag_name == tag_name))
ChatIdTagModel.model_validate(chat_id_tag)
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, tag_name=tag_name)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> int:
with get_db() as db:
return (
ChatIdTag.select()
.where((ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id))
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.count()
)
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
try:
query = ChatIdTag.delete().where(
(ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)
with get_db() as db:
res = (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.delete()
)
res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}")
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
query = Tag.delete().where(
(Tag.name == tag_name) & (Tag.user_id == user_id)
)
query.execute() # Remove the rows, return number of rows removed.
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True
except Exception as e:
log.error(f"delete_tag: {e}")
@ -202,21 +235,23 @@ class TagTable:
self, tag_name: str, chat_id: str, user_id: str
) -> bool:
try:
query = ChatIdTag.delete().where(
(ChatIdTag.tag_name == tag_name)
& (ChatIdTag.chat_id == chat_id)
& (ChatIdTag.user_id == user_id)
)
res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}")
with get_db() as db:
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
res = (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
.delete()
)
log.debug(f"res: {res}")
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
query = Tag.delete().where(
(Tag.name == tag_name) & (Tag.user_id == user_id)
)
query.execute() # Remove the rows, return number of rows removed.
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True
except Exception as e:
@ -234,4 +269,4 @@ class TagTable:
return True
Tags = TagTable(DB)
Tags = TagTable()

View File

@ -1,10 +1,10 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, get_db
from apps.webui.models.users import Users
import json
@ -21,19 +21,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Tool(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
valves = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Tool(Base):
__tablename__ = "tool"
class Meta:
database = DB
id = Column(String, primary_key=True)
user_id = Column(String)
name = Column(Text)
content = Column(Text)
specs = Column(JSONField)
meta = Column(JSONField)
valves = Column(JSONField)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class ToolMeta(BaseModel):
@ -51,6 +50,8 @@ class ToolModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@ -78,13 +79,13 @@ class ToolValves(BaseModel):
class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]:
with get_db() as db:
tool = ToolModel(
**{
**form_data.model_dump(),
@ -96,9 +97,12 @@ class ToolsTable:
)
try:
result = Tool.create(**tool.model_dump())
result = Tool(**tool.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return tool
return ToolModel.model_validate(result)
else:
return None
except Exception as e:
@ -107,17 +111,22 @@ class ToolsTable:
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try:
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
with get_db() as db:
tool = db.get(Tool, id)
return ToolModel.model_validate(tool)
except:
return None
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
with get_db() as db:
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try:
tool = Tool.get(Tool.id == id)
with get_db() as db:
tool = db.get(Tool, id)
return tool.valves if tool.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
@ -125,14 +134,13 @@ class ToolsTable:
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try:
query = Tool.update(
**{"valves": valves},
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
with get_db() as db:
tool = Tool.get(Tool.id == id)
return ToolValves(**model_to_dict(tool))
db.query(Tool).filter_by(id=id).update(
{"valves": valves, "updated_at": int(time.time())}
)
db.commit()
return self.get_tool_by_id(id)
except:
return None
@ -141,7 +149,7 @@ class ToolsTable:
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings
if "tools" not in user_settings:
@ -159,7 +167,7 @@ class ToolsTable:
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings
if "tools" not in user_settings:
@ -170,8 +178,7 @@ class ToolsTable:
user_settings["tools"]["valves"][id] = valves
# Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
Users.update_user_by_id(user_id, {"settings": user_settings})
return user_settings["tools"]["valves"][id]
except Exception as e:
@ -180,25 +187,27 @@ class ToolsTable:
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
query = Tool.update(
**updated,
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
with get_db() as db:
db.query(Tool).filter_by(id=id).update(
{**updated, "updated_at": int(time.time())}
)
db.commit()
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
tool = db.query(Tool).get(id)
db.refresh(tool)
return ToolModel.model_validate(tool)
except:
return None
def delete_tool_by_id(self, id: str) -> bool:
try:
query = Tool.delete().where((Tool.id == id))
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
db.query(Tool).filter_by(id=id).delete()
db.commit()
return True
except:
return False
Tools = ToolsTable(DB)
Tools = ToolsTable()

View File

@ -1,11 +1,12 @@
from pydantic import BaseModel, ConfigDict
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict, parse_obj_as
from typing import List, Union, Optional
import time
from sqlalchemy import String, Column, BigInteger, Text
from utils.misc import get_gravatar_url
from apps.webui.internal.db import DB, JSONField
from apps.webui.internal.db import Base, JSONField, Session, get_db
from apps.webui.models.chats import Chats
####################
@ -13,25 +14,24 @@ from apps.webui.models.chats import Chats
####################
class User(Model):
id = CharField(unique=True)
name = CharField()
email = CharField()
role = CharField()
profile_image_url = TextField()
class User(Base):
__tablename__ = "user"
last_active_at = BigIntegerField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
id = Column(String, primary_key=True)
name = Column(String)
email = Column(String)
role = Column(String)
profile_image_url = Column(Text)
api_key = CharField(null=True, unique=True)
settings = JSONField(null=True)
info = JSONField(null=True)
last_active_at = Column(BigInteger)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
oauth_sub = TextField(null=True, unique=True)
api_key = Column(String, nullable=True, unique=True)
settings = Column(JSONField, nullable=True)
info = Column(JSONField, nullable=True)
class Meta:
database = DB
oauth_sub = Column(Text, unique=True)
class UserSettings(BaseModel):
@ -57,6 +57,8 @@ class UserModel(BaseModel):
oauth_sub: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@ -76,9 +78,6 @@ class UserUpdateForm(BaseModel):
class UsersTable:
def __init__(self, db):
self.db = db
self.db.create_tables([User])
def insert_new_user(
self,
@ -89,6 +88,7 @@ class UsersTable:
role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
with get_db() as db:
user = UserModel(
**{
"id": id,
@ -102,7 +102,10 @@ class UsersTable:
"oauth_sub": oauth_sub,
}
)
result = User.create(**user.model_dump())
result = User(**user.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return user
else:
@ -110,56 +113,67 @@ class UsersTable:
def get_user_by_id(self, id: str) -> Optional[UserModel]:
try:
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
with get_db() as db:
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception as e:
return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try:
user = User.get(User.api_key == api_key)
return UserModel(**model_to_dict(user))
with get_db() as db:
user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user)
except:
return None
def get_user_by_email(self, email: str) -> Optional[UserModel]:
try:
user = User.get(User.email == email)
return UserModel(**model_to_dict(user))
with get_db() as db:
user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user)
except:
return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try:
user = User.get(User.oauth_sub == sub)
return UserModel(**model_to_dict(user))
with get_db() as db:
user = db.query(User).filter_by(oauth_sub=sub).first()
return UserModel.model_validate(user)
except:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [
UserModel(**model_to_dict(user))
for user in User.select()
# .limit(limit).offset(skip)
]
with get_db() as db:
users = (
db.query(User)
# .offset(skip).limit(limit)
.all()
)
return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]:
return User.select().count()
with get_db() as db:
return db.query(User).count()
def get_first_user(self) -> UserModel:
try:
user = User.select().order_by(User.created_at).first()
return UserModel(**model_to_dict(user))
with get_db() as db:
user = db.query(User).order_by(User.created_at).first()
return UserModel.model_validate(user)
except:
return None
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try:
query = User.update(role=role).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
with get_db() as db:
db.query(User).filter_by(id=id).update({"role": role})
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except:
return None
@ -167,23 +181,28 @@ class UsersTable:
self, id: str, profile_image_url: str
) -> Optional[UserModel]:
try:
query = User.update(profile_image_url=profile_image_url).where(
User.id == id
with get_db() as db:
db.query(User).filter_by(id=id).update(
{"profile_image_url": profile_image_url}
)
query.execute()
db.commit()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except:
return None
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try:
query = User.update(last_active_at=int(time.time())).where(User.id == id)
query.execute()
with get_db() as db:
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
db.query(User).filter_by(id=id).update(
{"last_active_at": int(time.time())}
)
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except:
return None
@ -191,22 +210,25 @@ class UsersTable:
self, id: str, oauth_sub: str
) -> Optional[UserModel]:
try:
query = User.update(oauth_sub=oauth_sub).where(User.id == id)
query.execute()
with get_db() as db:
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
db.commit()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
query = User.update(**updated).where(User.id == id)
query.execute()
with get_db() as db:
db.query(User).filter_by(id=id).update(updated)
db.commit()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
# return UserModel(**user.dict())
except Exception as e:
return None
def delete_user_by_id(self, id: str) -> bool:
@ -215,9 +237,10 @@ class UsersTable:
result = Chats.delete_chats_by_user_id(id)
if result:
with get_db() as db:
# Delete User
query = User.delete().where(User.id == id)
query.execute() # Remove the rows, return number of rows removed.
db.query(User).filter_by(id=id).delete()
db.commit()
return True
else:
@ -227,19 +250,20 @@ class UsersTable:
def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
try:
query = User.update(api_key=api_key).where(User.id == id)
result = query.execute()
with get_db() as db:
result = db.query(User).filter_by(id=id).update({"api_key": api_key})
db.commit()
return True if result == 1 else False
except:
return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try:
user = User.get(User.id == id)
with get_db() as db:
user = db.query(User).filter_by(id=id).first()
return user.api_key
except:
except Exception as e:
return None
Users = UsersTable(DB)
Users = UsersTable()

View File

@ -76,7 +76,10 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
@router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id(
user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50
user_id: str,
user=Depends(get_admin_user),
skip: int = 0,
limit: int = 50,
):
return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, skip=skip, limit=limit
@ -119,7 +122,7 @@ async def get_user_chats(user=Depends(get_verified_user)):
@router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_verified_user)):
async def get_user_archived_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(user.id)
@ -207,7 +210,6 @@ async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_verified_user)
):
print(form_data)
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(

View File

@ -130,7 +130,9 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_
@router.post("/doc/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
name: str,
form_data: DocumentUpdateForm,
user=Depends(get_admin_user),
):
doc = Documents.update_doc_by_name(name, form_data)
if doc:

View File

@ -50,10 +50,7 @@ router = APIRouter()
@router.post("/")
def upload_file(
file: UploadFile = File(...),
user=Depends(get_verified_user),
):
def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
log.info(f"file.content_type: {file.content_type}")
try:
unsanitized_filename = file.filename

View File

@ -233,7 +233,10 @@ async def delete_function_by_id(
# delete the function file
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
try:
os.remove(function_path)
except:
pass
return result

View File

@ -50,7 +50,9 @@ class MemoryUpdateModel(BaseModel):
@router.post("/add", response_model=Optional[MemoryModel])
async def add_memory(
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
request: Request,
form_data: AddMemoryForm,
user=Depends(get_verified_user),
):
memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)

View File

@ -5,6 +5,7 @@ from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user
@ -29,7 +30,9 @@ async def get_models(user=Depends(get_verified_user)):
@router.post("/add", response_model=Optional[ModelModel])
async def add_new_model(
request: Request, form_data: ModelForm, user=Depends(get_admin_user)
request: Request,
form_data: ModelForm,
user=Depends(get_admin_user),
):
if form_data.id in request.app.state.MODELS:
raise HTTPException(
@ -73,7 +76,10 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/update", response_model=Optional[ModelModel])
async def update_model_by_id(
request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user)
request: Request,
id: str,
form_data: ModelForm,
user=Depends(get_admin_user),
):
model = Models.get_model_by_id(id)
if model:

View File

@ -71,7 +71,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
@router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_admin_user)
command: str,
form_data: PromptForm,
user=Depends(get_admin_user),
):
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt:

View File

@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
@ -57,7 +56,9 @@ async def get_toolkits(user=Depends(get_admin_user)):
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user)
request: Request,
form_data: ToolForm,
user=Depends(get_admin_user),
):
if not form_data.id.isidentifier():
raise HTTPException(
@ -131,7 +132,10 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
request: Request,
id: str,
form_data: ToolForm,
user=Depends(get_admin_user),
):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")

View File

@ -138,7 +138,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/info/update", response_model=Optional[dict])
async def update_user_settings_by_session_user(
async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user)
):
user = Users.get_user_by_id(user.id)
@ -205,7 +205,9 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
@router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user)
user_id: str,
form_data: UserUpdateForm,
session_user=Depends(get_admin_user),
):
user = Users.get_user_by_id(user_id)

View File

@ -1,6 +1,5 @@
from fastapi import APIRouter, UploadFile, File, Response
from fastapi import Depends, HTTPException, status
from peewee import SqliteDatabase
from starlette.responses import StreamingResponse, FileResponse
from pydantic import BaseModel
@ -10,7 +9,6 @@ import markdown
import black
from apps.webui.internal.db import DB
from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url
@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if not isinstance(DB, SqliteDatabase):
from apps.webui.internal.db import engine
if engine.name != "sqlite":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
)
return FileResponse(
DB.database,
engine.url.database,
media_type="application/octet-stream",
filename="webui.db",
)

View File

@ -5,9 +5,8 @@ import importlib.metadata
import pkgutil
import chromadb
from chromadb import Settings
from base64 import b64encode
from bs4 import BeautifulSoup
from typing import TypeVar, Generic, Union
from typing import TypeVar, Generic
from pydantic import BaseModel
from typing import Optional
@ -19,7 +18,6 @@ import markdown
import requests
import shutil
from secrets import token_bytes
from constants import ERROR_MESSAGES
####################################
@ -395,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig(
os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
)
OAUTH_USERNAME_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.username_claim",
os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
)
OAUTH_PICTURE_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.avatar_claim",
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)
def load_oauth_providers():
OAUTH_PROVIDERS.clear()
@ -440,16 +450,27 @@ load_oauth_providers()
STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()
frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png"
frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"
if frontend_favicon.exists():
try:
shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
except Exception as e:
logging.error(f"An error occurred: {e}")
else:
logging.warning(f"Frontend favicon not found at {frontend_favicon}")
frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"
if frontend_splash.exists():
try:
shutil.copyfile(frontend_splash, STATIC_DIR / "splash.png")
except Exception as e:
logging.error(f"An error occurred: {e}")
else:
logging.warning(f"Frontend splash not found at {frontend_splash}")
####################################
# CUSTOM_NAME
####################################
@ -474,6 +495,19 @@ if CUSTOM_NAME:
r.raw.decode_content = True
shutil.copyfileobj(r.raw, f)
if "splash" in data:
url = (
f"https://api.openwebui.com{data['splash']}"
if data["splash"][0] == "/"
else data["splash"]
)
r = requests.get(url, stream=True)
if r.status_code == 200:
with open(f"{STATIC_DIR}/splash.png", "wb") as f:
r.raw.decode_content = True
shutil.copyfileobj(r.raw, f)
WEBUI_NAME = data["name"]
except Exception as e:
log.exception(e)
@ -769,11 +803,14 @@ class BannerModel(BaseModel):
timestamp: int
WEBUI_BANNERS = PersistentConfig(
"WEBUI_BANNERS",
"ui.banners",
[BannerModel(**banner) for banner in json.loads("[]")],
)
try:
banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
banners = [BannerModel(**banner) for banner in banners]
except Exception as e:
print(f"Error loading WEBUI_BANNERS: {e}")
banners = []
WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)
SHOW_ADMIN_DETAILS = PersistentConfig(
@ -885,6 +922,22 @@ WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
####################################
# RAG document content extraction
####################################
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
"CONTENT_EXTRACTION_ENGINE",
"rag.CONTENT_EXTRACTION_ENGINE",
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
)
TIKA_SERVER_URL = PersistentConfig(
"TIKA_SERVER_URL",
"rag.tika_server_url",
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
)
####################################
# RAG
####################################
@ -1302,3 +1355,7 @@ AUDIO_TTS_VOICE = PersistentConfig(
####################################
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")

View File

@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum):
OLLAMA_API_DISABLED = (
"The Ollama API is disabled. Please enable it to use this feature."
)
class TASKS(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = lambda task="": f"{task if task else 'default'}"
TITLE_GENERATION = "Title Generation"
EMOJI_GENERATION = "Emoji Generation"
QUERY_GENERATION = "Query Generation"
FUNCTION_CALLING = "Function Calling"

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,4 @@
Generic single-database configuration.
Create new migrations with
DATABASE_URL=<replace with actual url> alembic revision --autogenerate -m "a description"

96
backend/migrations/env.py Normal file
View File

@ -0,0 +1,96 @@
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from apps.webui.models.auths import Auth
from apps.webui.models.chats import Chat
from apps.webui.models.documents import Document
from apps.webui.models.memories import Memory
from apps.webui.models.models import Model
from apps.webui.models.prompts import Prompt
from apps.webui.models.tags import Tag, ChatIdTag
from apps.webui.models.tools import Tool
from apps.webui.models.users import User
from apps.webui.models.files import File
from apps.webui.models.functions import Function
from config import DATABASE_URL
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Auth.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
DB_URL = DATABASE_URL
if DB_URL:
config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%"))
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@ -0,0 +1,27 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import apps.webui.internal.db
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@ -0,0 +1,9 @@
from alembic import op
from sqlalchemy import Inspector
def get_existing_tables():
con = op.get_bind()
inspector = Inspector.from_engine(con)
tables = set(inspector.get_table_names())
return tables

View File

@ -0,0 +1,202 @@
"""init
Revision ID: 7e5b5dc7342b
Revises:
Create Date: 2024-06-24 13:15:33.808998
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import apps.webui.internal.db
from migrations.util import get_existing_tables
# revision identifiers, used by Alembic.
revision: str = "7e5b5dc7342b"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
existing_tables = set(get_existing_tables())
# ### commands auto generated by Alembic - please adjust! ###
if "auth" not in existing_tables:
op.create_table(
"auth",
sa.Column("id", sa.String(), nullable=False),
sa.Column("email", sa.String(), nullable=True),
sa.Column("password", sa.Text(), nullable=True),
sa.Column("active", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "chat" not in existing_tables:
op.create_table(
"chat",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("chat", sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("share_id", sa.Text(), nullable=True),
sa.Column("archived", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("share_id"),
)
if "chatidtag" not in existing_tables:
op.create_table(
"chatidtag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("tag_name", sa.String(), nullable=True),
sa.Column("chat_id", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "document" not in existing_tables:
op.create_table(
"document",
sa.Column("collection_name", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("collection_name"),
sa.UniqueConstraint("name"),
)
if "file" not in existing_tables:
op.create_table(
"file",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "function" not in existing_tables:
op.create_table(
"function",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("type", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=True),
sa.Column("is_global", sa.Boolean(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "memory" not in existing_tables:
op.create_table(
"memory",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "model" not in existing_tables:
op.create_table(
"model",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=True),
sa.Column("base_model_id", sa.Text(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "prompt" not in existing_tables:
op.create_table(
"prompt",
sa.Column("command", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("command"),
)
if "tag" not in existing_tables:
op.create_table(
"tag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("data", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "tool" not in existing_tables:
op.create_table(
"tool",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "user" not in existing_tables:
op.create_table(
"user",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("email", sa.String(), nullable=True),
sa.Column("role", sa.String(), nullable=True),
sa.Column("profile_image_url", sa.Text(), nullable=True),
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("oauth_sub", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("api_key"),
sa.UniqueConstraint("oauth_sub"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("user")
op.drop_table("tool")
op.drop_table("tag")
op.drop_table("prompt")
op.drop_table("model")
op.drop_table("memory")
op.drop_table("function")
op.drop_table("file")
op.drop_table("document")
op.drop_table("chatidtag")
op.drop_table("chat")
op.drop_table("auth")
# ### end Alembic commands ###

View File

@ -10,9 +10,11 @@ python-socketio==5.11.3
python-jose==3.3.0
passlib[bcrypt]==1.7.4
requests==2.32.2
requests==2.32.3
aiohttp==3.9.5
peewee==3.17.5
sqlalchemy==2.0.30
alembic==1.13.2
peewee==3.17.6
peewee-migrate==1.12.2
psycopg2-binary==2.9.9
PyMySQL==1.1.1
@ -30,26 +32,26 @@ openai
anthropic
google-generativeai==0.5.4
langchain==0.2.0
langchain-community==0.2.0
langchain==0.2.6
langchain-community==0.2.6
langchain-chroma==0.1.2
fake-useragent==1.5.1
chromadb==0.5.3
sentence-transformers==2.7.0
sentence-transformers==3.0.1
pypdf==4.2.0
docx2txt==0.8
python-pptx==0.6.23
unstructured==0.14.0
unstructured==0.14.9
Markdown==3.6
pypandoc==1.13
pandas==2.2.2
openpyxl==3.1.2
openpyxl==3.1.5
pyxlsb==1.0.10
xlrd==2.0.1
validators==0.28.1
opencv-python-headless==4.9.0.80
opencv-python-headless==4.10.0.84
rapidocr-onnxruntime==1.3.22
fpdf2==2.7.9
@ -61,10 +63,15 @@ PyJWT[crypto]==2.8.0
authlib==1.3.1
black==24.4.2
langfuse==2.33.0
langfuse==2.38.0
youtube-transcript-api==0.6.2
pytube==15.0.0
extract_msg
pydub
duckduckgo-search~=6.1.7
## Tests
docker~=7.1.0
pytest~=8.2.2
pytest-docker~=3.1.1

BIN
backend/static/splash.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.1 KiB

0
backend/test/__init__.py Normal file
View File

View File

@ -0,0 +1,202 @@
import pytest
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestAuths(AbstractPostgresTest):
BASE_PATH = "/api/v1/auths"
def setup_class(cls):
super().setup_class()
from apps.webui.models.users import Users
from apps.webui.models.auths import Auths
cls.users = Users
cls.auths = Auths
def test_get_session_user(self):
with mock_webui_user():
response = self.fast_api_client.get(self.create_url(""))
assert response.status_code == 200
assert response.json() == {
"id": "1",
"name": "John Doe",
"email": "john.doe@openwebui.com",
"role": "user",
"profile_image_url": "/user.png",
}
def test_update_profile(self):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password=get_password_hash("old_password"),
name="John Doe",
profile_image_url="/user.png",
role="user",
)
with mock_webui_user(id=user.id):
response = self.fast_api_client.post(
self.create_url("/update/profile"),
json={"name": "John Doe 2", "profile_image_url": "/user2.png"},
)
assert response.status_code == 200
db_user = self.users.get_user_by_id(user.id)
assert db_user.name == "John Doe 2"
assert db_user.profile_image_url == "/user2.png"
def test_update_password(self):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password=get_password_hash("old_password"),
name="John Doe",
profile_image_url="/user.png",
role="user",
)
with mock_webui_user(id=user.id):
response = self.fast_api_client.post(
self.create_url("/update/password"),
json={"password": "old_password", "new_password": "new_password"},
)
assert response.status_code == 200
old_auth = self.auths.authenticate_user(
"john.doe@openwebui.com", "old_password"
)
assert old_auth is None
new_auth = self.auths.authenticate_user(
"john.doe@openwebui.com", "new_password"
)
assert new_auth is not None
def test_signin(self):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password=get_password_hash("password"),
name="John Doe",
profile_image_url="/user.png",
role="user",
)
response = self.fast_api_client.post(
self.create_url("/signin"),
json={"email": "john.doe@openwebui.com", "password": "password"},
)
assert response.status_code == 200
data = response.json()
assert data["id"] == user.id
assert data["name"] == "John Doe"
assert data["email"] == "john.doe@openwebui.com"
assert data["role"] == "user"
assert data["profile_image_url"] == "/user.png"
assert data["token"] is not None and len(data["token"]) > 0
assert data["token_type"] == "Bearer"
def test_signup(self):
response = self.fast_api_client.post(
self.create_url("/signup"),
json={
"name": "John Doe",
"email": "john.doe@openwebui.com",
"password": "password",
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] is not None and len(data["id"]) > 0
assert data["name"] == "John Doe"
assert data["email"] == "john.doe@openwebui.com"
assert data["role"] in ["admin", "user", "pending"]
assert data["profile_image_url"] == "/user.png"
assert data["token"] is not None and len(data["token"]) > 0
assert data["token_type"] == "Bearer"
def test_add_user(self):
with mock_webui_user():
response = self.fast_api_client.post(
self.create_url("/add"),
json={
"name": "John Doe 2",
"email": "john.doe2@openwebui.com",
"password": "password2",
"role": "admin",
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] is not None and len(data["id"]) > 0
assert data["name"] == "John Doe 2"
assert data["email"] == "john.doe2@openwebui.com"
assert data["role"] == "admin"
assert data["profile_image_url"] == "/user.png"
assert data["token"] is not None and len(data["token"]) > 0
assert data["token_type"] == "Bearer"
def test_get_admin_details(self):
self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
with mock_webui_user():
response = self.fast_api_client.get(self.create_url("/admin/details"))
assert response.status_code == 200
assert response.json() == {
"name": "John Doe",
"email": "john.doe@openwebui.com",
}
def test_create_api_key_(self):
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
with mock_webui_user(id=user.id):
response = self.fast_api_client.post(self.create_url("/api_key"))
assert response.status_code == 200
data = response.json()
assert data["api_key"] is not None
assert len(data["api_key"]) > 0
def test_delete_api_key(self):
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
self.users.update_user_api_key_by_id(user.id, "abc")
with mock_webui_user(id=user.id):
response = self.fast_api_client.delete(self.create_url("/api_key"))
assert response.status_code == 200
assert response.json() == True
db_user = self.users.get_user_by_id(user.id)
assert db_user.api_key is None
def test_get_api_key(self):
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
self.users.update_user_api_key_by_id(user.id, "abc")
with mock_webui_user(id=user.id):
response = self.fast_api_client.get(self.create_url("/api_key"))
assert response.status_code == 200
assert response.json() == {"api_key": "abc"}

View File

@ -0,0 +1,238 @@
import uuid
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestChats(AbstractPostgresTest):
BASE_PATH = "/api/v1/chats"
def setup_class(cls):
super().setup_class()
def setup_method(self):
super().setup_method()
from apps.webui.models.chats import ChatForm
from apps.webui.models.chats import Chats
self.chats = Chats
self.chats.insert_new_chat(
"2",
ChatForm(
**{
"chat": {
"name": "chat1",
"description": "chat1 description",
"tags": ["tag1", "tag2"],
"history": {"currentId": "1", "messages": []},
}
}
),
)
def test_get_session_user_chat_list(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
first_chat = response.json()[0]
assert first_chat["id"] is not None
assert first_chat["title"] == "New Chat"
assert first_chat["created_at"] is not None
assert first_chat["updated_at"] is not None
def test_delete_all_user_chats(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url("/"))
assert response.status_code == 200
assert len(self.chats.get_chats()) == 0
def test_get_user_chat_list_by_user_id(self):
with mock_webui_user(id="3"):
response = self.fast_api_client.get(self.create_url("/list/user/2"))
assert response.status_code == 200
first_chat = response.json()[0]
assert first_chat["id"] is not None
assert first_chat["title"] == "New Chat"
assert first_chat["created_at"] is not None
assert first_chat["updated_at"] is not None
def test_create_new_chat(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/new"),
json={
"chat": {
"name": "chat2",
"description": "chat2 description",
"tags": ["tag1", "tag2"],
}
},
)
assert response.status_code == 200
data = response.json()
assert data["archived"] is False
assert data["chat"] == {
"name": "chat2",
"description": "chat2 description",
"tags": ["tag1", "tag2"],
}
assert data["user_id"] == "2"
assert data["id"] is not None
assert data["share_id"] is None
assert data["title"] == "New Chat"
assert data["updated_at"] is not None
assert data["created_at"] is not None
assert len(self.chats.get_chats()) == 2
def test_get_user_chats(self):
self.test_get_session_user_chat_list()
def test_get_user_archived_chats(self):
self.chats.archive_all_chats_by_user_id("2")
from apps.webui.internal.db import Session
Session.commit()
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/all/archived"))
assert response.status_code == 200
first_chat = response.json()[0]
assert first_chat["id"] is not None
assert first_chat["title"] == "New Chat"
assert first_chat["created_at"] is not None
assert first_chat["updated_at"] is not None
def test_get_all_user_chats_in_db(self):
with mock_webui_user(id="4"):
response = self.fast_api_client.get(self.create_url("/all/db"))
assert response.status_code == 200
assert len(response.json()) == 1
def test_get_archived_session_user_chat_list(self):
self.test_get_user_archived_chats()
def test_archive_all_chats(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.post(self.create_url("/archive/all"))
assert response.status_code == 200
assert len(self.chats.get_archived_chats_by_user_id("2")) == 1
def test_get_shared_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
self.chats.update_chat_share_id_by_id(chat_id, chat_id)
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}"))
assert response.status_code == 200
data = response.json()
assert data["id"] == chat_id
assert data["chat"] == {
"name": "chat1",
"description": "chat1 description",
"tags": ["tag1", "tag2"],
"history": {"currentId": "1", "messages": []},
}
assert data["id"] == chat_id
assert data["share_id"] == chat_id
assert data["title"] == "New Chat"
def test_get_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}"))
assert response.status_code == 200
data = response.json()
assert data["id"] == chat_id
assert data["chat"] == {
"name": "chat1",
"description": "chat1 description",
"tags": ["tag1", "tag2"],
"history": {"currentId": "1", "messages": []},
}
assert data["share_id"] is None
assert data["title"] == "New Chat"
assert data["user_id"] == "2"
def test_update_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url(f"/{chat_id}"),
json={
"chat": {
"name": "chat2",
"description": "chat2 description",
"tags": ["tag2", "tag4"],
"title": "Just another title",
}
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] == chat_id
assert data["chat"] == {
"name": "chat2",
"title": "Just another title",
"description": "chat2 description",
"tags": ["tag2", "tag4"],
"history": {"currentId": "1", "messages": []},
}
assert data["share_id"] is None
assert data["title"] == "Just another title"
assert data["user_id"] == "2"
def test_delete_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url(f"/{chat_id}"))
assert response.status_code == 200
assert response.json() is True
def test_clone_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone"))
assert response.status_code == 200
data = response.json()
assert data["id"] != chat_id
assert data["chat"] == {
"branchPointMessageId": "1",
"description": "chat1 description",
"history": {"currentId": "1", "messages": []},
"name": "chat1",
"originalChatId": chat_id,
"tags": ["tag1", "tag2"],
"title": "Clone of New Chat",
}
assert data["share_id"] is None
assert data["title"] == "Clone of New Chat"
assert data["user_id"] == "2"
def test_archive_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive"))
assert response.status_code == 200
chat = self.chats.get_chat_by_id(chat_id)
assert chat.archived is True
def test_share_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share"))
assert response.status_code == 200
chat = self.chats.get_chat_by_id(chat_id)
assert chat.share_id is not None
def test_delete_shared_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
share_id = str(uuid.uuid4())
self.chats.update_chat_share_id_by_id(chat_id, share_id)
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share"))
assert response.status_code
chat = self.chats.get_chat_by_id(chat_id)
assert chat.share_id is None

View File

@ -0,0 +1,106 @@
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestDocuments(AbstractPostgresTest):
BASE_PATH = "/api/v1/documents"
def setup_class(cls):
super().setup_class()
from apps.webui.models.documents import Documents
cls.documents = Documents
def test_documents(self):
# Empty database
assert len(self.documents.get_docs()) == 0
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 0
# Create a new document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"name": "doc_name",
"title": "doc title",
"collection_name": "custom collection",
"filename": "doc_name.pdf",
"content": "",
},
)
assert response.status_code == 200
assert response.json()["name"] == "doc_name"
assert len(self.documents.get_docs()) == 1
# Get the document
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/doc?name=doc_name"))
assert response.status_code == 200
data = response.json()
assert data["collection_name"] == "custom collection"
assert data["name"] == "doc_name"
assert data["title"] == "doc title"
assert data["filename"] == "doc_name.pdf"
assert data["content"] == {}
# Create another document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"name": "doc_name 2",
"title": "doc title 2",
"collection_name": "custom collection 2",
"filename": "doc_name2.pdf",
"content": "",
},
)
assert response.status_code == 200
assert response.json()["name"] == "doc_name 2"
assert len(self.documents.get_docs()) == 2
# Get all documents
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 2
# Update the first document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/doc/update?name=doc_name"),
json={"name": "doc_name rework", "title": "updated title"},
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "doc_name rework"
assert data["title"] == "updated title"
# Tag the first document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/doc/tags"),
json={
"name": "doc_name rework",
"tags": [{"name": "testing-tag"}, {"name": "another-tag"}],
},
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "doc_name rework"
assert data["content"] == {
"tags": [{"name": "testing-tag"}, {"name": "another-tag"}]
}
assert len(self.documents.get_docs()) == 2
# Delete the first document
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(
self.create_url("/doc/delete?name=doc_name rework")
)
assert response.status_code == 200
assert len(self.documents.get_docs()) == 1

View File

@ -0,0 +1,62 @@
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestModels(AbstractPostgresTest):
BASE_PATH = "/api/v1/models"
def setup_class(cls):
super().setup_class()
from apps.webui.models.models import Model
cls.models = Model
def test_models(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 0
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/add"),
json={
"id": "my-model",
"base_model_id": "base-model-id",
"name": "Hello World",
"meta": {
"profile_image_url": "/static/favicon.png",
"description": "description",
"capabilities": None,
"model_config": {},
},
"params": {},
},
)
assert response.status_code == 200
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 1
with mock_webui_user(id="2"):
response = self.fast_api_client.get(
self.create_url(query_params={"id": "my-model"})
)
assert response.status_code == 200
data = response.json()[0]
assert data["id"] == "my-model"
assert data["name"] == "Hello World"
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(
self.create_url("/delete?id=my-model")
)
assert response.status_code == 200
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 0

View File

@ -0,0 +1,92 @@
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestPrompts(AbstractPostgresTest):
BASE_PATH = "/api/v1/prompts"
def test_prompts(self):
# Get all prompts
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 0
# Create a two new prompts
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"command": "/my-command",
"title": "Hello World",
"content": "description",
},
)
assert response.status_code == 200
with mock_webui_user(id="3"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"command": "/my-command2",
"title": "Hello World 2",
"content": "description 2",
},
)
assert response.status_code == 200
# Get all prompts
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 2
# Get prompt by command
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/command/my-command"))
assert response.status_code == 200
data = response.json()
assert data["command"] == "/my-command"
assert data["title"] == "Hello World"
assert data["content"] == "description"
assert data["user_id"] == "2"
# Update prompt
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/command/my-command2/update"),
json={
"command": "irrelevant for request",
"title": "Hello World Updated",
"content": "description Updated",
},
)
assert response.status_code == 200
data = response.json()
assert data["command"] == "/my-command2"
assert data["title"] == "Hello World Updated"
assert data["content"] == "description Updated"
assert data["user_id"] == "3"
# Get prompt by command
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/command/my-command2"))
assert response.status_code == 200
data = response.json()
assert data["command"] == "/my-command2"
assert data["title"] == "Hello World Updated"
assert data["content"] == "description Updated"
assert data["user_id"] == "3"
# Delete prompt
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(
self.create_url("/command/my-command/delete")
)
assert response.status_code == 200
# Get all prompts
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 1

View File

@ -0,0 +1,168 @@
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
def _get_user_by_id(data, param):
return next((item for item in data if item["id"] == param), None)
def _assert_user(data, id, **kwargs):
user = _get_user_by_id(data, id)
assert user is not None
comparison_data = {
"name": f"user {id}",
"email": f"user{id}@openwebui.com",
"profile_image_url": f"/user{id}.png",
"role": "user",
**kwargs,
}
for key, value in comparison_data.items():
assert user[key] == value
class TestUsers(AbstractPostgresTest):
BASE_PATH = "/api/v1/users"
def setup_class(cls):
super().setup_class()
from apps.webui.models.users import Users
cls.users = Users
def setup_method(self):
super().setup_method()
self.users.insert_new_user(
id="1",
name="user 1",
email="user1@openwebui.com",
profile_image_url="/user1.png",
role="user",
)
self.users.insert_new_user(
id="2",
name="user 2",
email="user2@openwebui.com",
profile_image_url="/user2.png",
role="user",
)
def test_users(self):
# Get all users
with mock_webui_user(id="3"):
response = self.fast_api_client.get(self.create_url(""))
assert response.status_code == 200
assert len(response.json()) == 2
data = response.json()
_assert_user(data, "1")
_assert_user(data, "2")
# update role
with mock_webui_user(id="3"):
response = self.fast_api_client.post(
self.create_url("/update/role"), json={"id": "2", "role": "admin"}
)
assert response.status_code == 200
_assert_user([response.json()], "2", role="admin")
# Get all users
with mock_webui_user(id="3"):
response = self.fast_api_client.get(self.create_url(""))
assert response.status_code == 200
assert len(response.json()) == 2
data = response.json()
_assert_user(data, "1")
_assert_user(data, "2", role="admin")
# Get (empty) user settings
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/user/settings"))
assert response.status_code == 200
assert response.json() is None
# Update user settings
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/user/settings/update"),
json={
"ui": {"attr1": "value1", "attr2": "value2"},
"model_config": {"attr3": "value3", "attr4": "value4"},
},
)
assert response.status_code == 200
# Get user settings
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/user/settings"))
assert response.status_code == 200
assert response.json() == {
"ui": {"attr1": "value1", "attr2": "value2"},
"model_config": {"attr3": "value3", "attr4": "value4"},
}
# Get (empty) user info
with mock_webui_user(id="1"):
response = self.fast_api_client.get(self.create_url("/user/info"))
assert response.status_code == 200
assert response.json() is None
# Update user info
with mock_webui_user(id="1"):
response = self.fast_api_client.post(
self.create_url("/user/info/update"),
json={"attr1": "value1", "attr2": "value2"},
)
assert response.status_code == 200
# Get user info
with mock_webui_user(id="1"):
response = self.fast_api_client.get(self.create_url("/user/info"))
assert response.status_code == 200
assert response.json() == {"attr1": "value1", "attr2": "value2"}
# Get user by id
with mock_webui_user(id="1"):
response = self.fast_api_client.get(self.create_url("/2"))
assert response.status_code == 200
assert response.json() == {"name": "user 2", "profile_image_url": "/user2.png"}
# Update user by id
with mock_webui_user(id="1"):
response = self.fast_api_client.post(
self.create_url("/2/update"),
json={
"name": "user 2 updated",
"email": "user2-updated@openwebui.com",
"profile_image_url": "/user2-updated.png",
},
)
assert response.status_code == 200
# Get all users
with mock_webui_user(id="3"):
response = self.fast_api_client.get(self.create_url(""))
assert response.status_code == 200
assert len(response.json()) == 2
data = response.json()
_assert_user(data, "1")
_assert_user(
data,
"2",
role="admin",
name="user 2 updated",
email="user2-updated@openwebui.com",
profile_image_url="/user2-updated.png",
)
# Delete user by id
with mock_webui_user(id="1"):
response = self.fast_api_client.delete(self.create_url("/2"))
assert response.status_code == 200
# Get all users
with mock_webui_user(id="3"):
response = self.fast_api_client.get(self.create_url(""))
assert response.status_code == 200
assert len(response.json()) == 1
data = response.json()
_assert_user(data, "1")

View File

@ -0,0 +1,161 @@
import logging
import os
import time
import docker
import pytest
from docker import DockerClient
from pytest_docker.plugin import get_docker_ip
from fastapi.testclient import TestClient
from sqlalchemy import text, create_engine
log = logging.getLogger(__name__)
def get_fast_api_client():
from main import app
with TestClient(app) as c:
return c
class AbstractIntegrationTest:
BASE_PATH = None
def create_url(self, path="", query_params=None):
if self.BASE_PATH is None:
raise Exception("BASE_PATH is not set")
parts = self.BASE_PATH.split("/")
parts = [part.strip() for part in parts if part.strip() != ""]
path_parts = path.split("/")
path_parts = [part.strip() for part in path_parts if part.strip() != ""]
query_parts = ""
if query_params:
query_parts = "&".join(
[f"{key}={value}" for key, value in query_params.items()]
)
query_parts = f"?{query_parts}"
return "/".join(parts + path_parts) + query_parts
@classmethod
def setup_class(cls):
pass
def setup_method(self):
pass
@classmethod
def teardown_class(cls):
pass
def teardown_method(self):
pass
class AbstractPostgresTest(AbstractIntegrationTest):
DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
docker_client: DockerClient
@classmethod
def _create_db_url(cls, env_vars_postgres: dict) -> str:
host = get_docker_ip()
user = env_vars_postgres["POSTGRES_USER"]
pw = env_vars_postgres["POSTGRES_PASSWORD"]
port = 8081
db = env_vars_postgres["POSTGRES_DB"]
return f"postgresql://{user}:{pw}@{host}:{port}/{db}"
@classmethod
def setup_class(cls):
super().setup_class()
try:
env_vars_postgres = {
"POSTGRES_USER": "user",
"POSTGRES_PASSWORD": "example",
"POSTGRES_DB": "openwebui",
}
cls.docker_client = docker.from_env()
cls.docker_client.containers.run(
"postgres:16.2",
detach=True,
environment=env_vars_postgres,
name=cls.DOCKER_CONTAINER_NAME,
ports={5432: ("0.0.0.0", 8081)},
command="postgres -c log_statement=all",
)
time.sleep(0.5)
database_url = cls._create_db_url(env_vars_postgres)
os.environ["DATABASE_URL"] = database_url
retries = 10
db = None
while retries > 0:
try:
from config import BACKEND_DIR
db = create_engine(database_url, pool_pre_ping=True)
db = db.connect()
log.info("postgres is ready!")
break
except Exception as e:
log.warning(e)
time.sleep(3)
retries -= 1
if db:
# import must be after setting env!
cls.fast_api_client = get_fast_api_client()
db.close()
else:
raise Exception("Could not connect to Postgres")
except Exception as ex:
log.error(ex)
cls.teardown_class()
pytest.fail(f"Could not setup test environment: {ex}")
def _check_db_connection(self):
from apps.webui.internal.db import Session
retries = 10
while retries > 0:
try:
Session.execute(text("SELECT 1"))
Session.commit()
break
except Exception as e:
Session.rollback()
log.warning(e)
time.sleep(3)
retries -= 1
def setup_method(self):
super().setup_method()
self._check_db_connection()
@classmethod
def teardown_class(cls) -> None:
super().teardown_class()
cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
def teardown_method(self):
from apps.webui.internal.db import Session
# rollback everything not yet committed
Session.commit()
# truncate all tables
tables = [
"auth",
"chat",
"chatidtag",
"document",
"memory",
"model",
"prompt",
"tag",
'"user"',
]
for table in tables:
Session.execute(text(f"TRUNCATE TABLE {table}"))
Session.commit()

View File

@ -0,0 +1,45 @@
from contextlib import contextmanager
from fastapi import FastAPI
@contextmanager
def mock_webui_user(**kwargs):
from apps.webui.main import app
with mock_user(app, **kwargs):
yield
@contextmanager
def mock_user(app: FastAPI, **kwargs):
from utils.utils import (
get_current_user,
get_verified_user,
get_admin_user,
get_current_user_by_api_key,
)
from apps.webui.models.users import User
def create_user():
user_parameters = {
"id": "1",
"name": "John Doe",
"email": "john.doe@openwebui.com",
"role": "user",
"profile_image_url": "/user.png",
"last_active_at": 1627351200,
"updated_at": 1627351200,
"created_at": 162735120,
**kwargs,
}
return User(**user_parameters)
app.dependency_overrides = {
get_current_user: create_user,
get_verified_user: create_user,
get_admin_user: create_user,
get_current_user_by_api_key: create_user,
}
yield
app.dependency_overrides = {}

View File

@ -8,9 +8,17 @@ import uuid
import time
def get_last_user_message(messages: List[dict]) -> str:
def get_last_user_message_item(messages: List[dict]) -> str:
for message in reversed(messages):
if message["role"] == "user":
return message
return None
def get_last_user_message(messages: List[dict]) -> str:
message = get_last_user_message_item(messages)
if message is not None:
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":

View File

@ -59,7 +59,10 @@ def get_tools_specs(tools) -> List[dict]:
for param_name, param_annotation in get_type_hints(
function
).items()
if param_name != "return" and param_name != "__user__"
if param_name != "return"
and not (
param_name.startswith("__") and param_name.endswith("__")
)
},
"required": [
name

View File

@ -1,5 +1,6 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends, Request
from sqlalchemy.orm import Session
from apps.webui.models.users import Users

4
package-lock.json generated
View File

@ -1,12 +1,12 @@
{
"name": "open-webui",
"version": "0.3.7",
"version": "0.3.8",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "open-webui",
"version": "0.3.7",
"version": "0.3.8",
"dependencies": {
"@codemirror/lang-javascript": "^6.2.2",
"@codemirror/lang-python": "^6.1.6",

View File

@ -1,6 +1,6 @@
{
"name": "open-webui",
"version": "0.3.7",
"version": "0.3.8",
"private": true,
"scripts": {
"dev": "npm run pyodide:fetch && vite dev --host",

View File

@ -7,6 +7,7 @@
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false
-e file:.
aiohttp==3.9.5
@ -31,7 +32,7 @@ asgiref==3.8.1
# via opentelemetry-instrumentation-asgi
attrs==23.2.0
# via aiohttp
authlib==1.3.0
authlib==1.3.1
# via open-webui
av==11.0.0
# via faster-whisper
@ -398,7 +399,6 @@ pandas==2.2.2
# via open-webui
passlib==1.7.4
# via open-webui
# via passlib
pathspec==0.12.1
# via black
pcodedmp==1.2.6
@ -457,7 +457,6 @@ pygments==2.18.0
# via rich
pyjwt==2.8.0
# via open-webui
# via pyjwt
pymysql==1.1.0
# via open-webui
pypandoc==1.13
@ -559,6 +558,9 @@ scipy==1.13.0
# via sentence-transformers
sentence-transformers==2.7.0
# via open-webui
setuptools==69.5.1
# via ctranslate2
# via opentelemetry-instrumentation
shapely==2.0.4
# via rapidocr-onnxruntime
shellingham==1.5.4
@ -653,7 +655,6 @@ uvicorn==0.22.0
# via chromadb
# via fastapi
# via open-webui
# via uvicorn
uvloop==0.19.0
# via uvicorn
validators==0.28.1
@ -681,6 +682,3 @@ youtube-transcript-api==0.6.2
# via open-webui
zipp==3.18.1
# via importlib-metadata
setuptools==69.5.1
# via ctranslate2
# via opentelemetry-instrumentation

View File

@ -7,6 +7,7 @@
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false
-e file:.
aiohttp==3.9.5
@ -31,7 +32,7 @@ asgiref==3.8.1
# via opentelemetry-instrumentation-asgi
attrs==23.2.0
# via aiohttp
authlib==1.3.0
authlib==1.3.1
# via open-webui
av==11.0.0
# via faster-whisper
@ -398,7 +399,6 @@ pandas==2.2.2
# via open-webui
passlib==1.7.4
# via open-webui
# via passlib
pathspec==0.12.1
# via black
pcodedmp==1.2.6
@ -457,7 +457,6 @@ pygments==2.18.0
# via rich
pyjwt==2.8.0
# via open-webui
# via pyjwt
pymysql==1.1.0
# via open-webui
pypandoc==1.13
@ -559,6 +558,9 @@ scipy==1.13.0
# via sentence-transformers
sentence-transformers==2.7.0
# via open-webui
setuptools==69.5.1
# via ctranslate2
# via opentelemetry-instrumentation
shapely==2.0.4
# via rapidocr-onnxruntime
shellingham==1.5.4
@ -653,7 +655,6 @@ uvicorn==0.22.0
# via chromadb
# via fastapi
# via open-webui
# via uvicorn
uvloop==0.19.0
# via uvicorn
validators==0.28.1
@ -681,6 +682,3 @@ youtube-transcript-api==0.6.2
# via open-webui
zipp==3.18.1
# via importlib-metadata
setuptools==69.5.1
# via ctranslate2
# via opentelemetry-instrumentation

View File

@ -1,6 +1,12 @@
@font-face {
font-family: 'Arimo';
src: url('/assets/fonts/Arimo-Variable.ttf');
font-family: 'Inter';
src: url('/assets/fonts/Inter-Variable.ttf');
font-display: swap;
}
@font-face {
font-family: 'Archivo';
src: url('/assets/fonts/Archivo-Variable.ttf');
font-display: swap;
}
@ -32,6 +38,10 @@ math {
@apply underline;
}
.font-primary {
font-family: 'Archivo', sans-serif;
}
iframe {
@apply rounded-lg;
}
@ -140,3 +150,7 @@ input[type='number'] {
.cm-editor.cm-focused {
outline: none;
}
.tippy-box[data-theme~='dark'] {
@apply rounded-lg bg-gray-950 text-xs border border-gray-900 shadow-xl;
}

View File

@ -23,6 +23,8 @@
// On page load or when changing themes, best to add inline in `head` to avoid FOUC
(() => {
if (localStorage?.theme && localStorage?.theme.includes('oled')) {
document.documentElement.style.setProperty('--color-gray-800', '#101010');
document.documentElement.style.setProperty('--color-gray-850', '#050505');
document.documentElement.style.setProperty('--color-gray-900', '#000000');
document.documentElement.style.setProperty('--color-gray-950', '#000000');
document.documentElement.classList.add('dark');
@ -80,13 +82,13 @@
id="logo"
style="
position: absolute;
width: 6rem;
width: auto;
height: 6rem;
top: 41%;
top: 44%;
left: 50%;
margin-left: -3rem;
"
src="/logo.svg"
src="/static/splash.png"
/>
<div
@ -105,8 +107,8 @@
>
<img
id="logo-her"
style="width: 13rem; height: 13rem"
src="/logo.svg"
style="width: auto; height: 13rem"
src="/static/splash.png"
class="animate-pulse-fast"
/>

View File

@ -32,6 +32,11 @@ type ChunkConfigForm = {
chunk_overlap: number;
};
type ContentExtractConfigForm = {
engine: string;
tika_server_url: string | null;
};
type YoutubeConfigForm = {
language: string[];
translation?: string | null;
@ -40,6 +45,7 @@ type YoutubeConfigForm = {
type RAGConfigForm = {
pdf_extract_images?: boolean;
chunk?: ChunkConfigForm;
content_extraction?: ContentExtractConfigForm;
web_loader_ssl_verification?: boolean;
youtube?: YoutubeConfigForm;
};

View File

@ -24,7 +24,7 @@
<Modal bind:show>
<div class="px-5 pt-4 dark:text-gray-300 text-gray-700">
<div class="flex justify-between items-start">
<div class="text-xl font-bold">
<div class="text-xl font-semibold">
{$i18n.t('Whats New in')}
{$WEBUI_NAME}
<Confetti x={[-1, -0.25]} y={[0, 0.5]} />
@ -63,7 +63,7 @@
{#if changelog}
{#each Object.keys(changelog) as version}
<div class=" mb-3 pr-2">
<div class="font-bold text-xl mb-1 dark:text-white">
<div class="font-semibold text-xl mb-1 dark:text-white">
v{version} - {changelog[version].date}
</div>
@ -72,7 +72,7 @@
{#each Object.keys(changelog[version]).filter((section) => section !== 'date') as section}
<div class="">
<div
class="font-bold uppercase text-xs {section === 'added'
class="font-semibold uppercase text-xs {section === 'added'
? 'text-white bg-blue-600'
: section === 'fixed'
? 'text-white bg-green-600'

View File

@ -1,5 +1,5 @@
<script>
import { getContext, tick } from 'svelte';
import { getContext, tick, onMount } from 'svelte';
import { toast } from 'svelte-sonner';
import Database from './Settings/Database.svelte';
@ -21,17 +21,31 @@
const i18n = getContext('i18n');
let selectedTab = 'general';
onMount(() => {
const containerElement = document.getElementById('admin-settings-tabs-container');
if (containerElement) {
containerElement.addEventListener('wheel', function (event) {
if (event.deltaY !== 0) {
// Adjust horizontal scroll position based on vertical scroll
containerElement.scrollLeft += event.deltaY;
}
});
}
});
</script>
<div class="flex flex-col lg:flex-row w-full h-full py-2 lg:space-x-4">
<div
id="admin-settings-tabs-container"
class="tabs flex flex-row overflow-x-auto space-x-1 max-w-full lg:space-x-0 lg:space-y-1 lg:flex-col lg:flex-none lg:w-44 dark:text-gray-200 text-xs text-left scrollbar-none"
>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 lg:flex-none flex text-right transition {selectedTab ===
'general'
? 'bg-gray-200 dark:bg-gray-800'
: ' hover:bg-gray-300 dark:hover:bg-gray-850'}"
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
on:click={() => {
selectedTab = 'general';
}}
@ -56,8 +70,8 @@
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'users'
? 'bg-gray-200 dark:bg-gray-800'
: ' hover:bg-gray-300 dark:hover:bg-gray-850'}"
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
on:click={() => {
selectedTab = 'users';
}}
@ -80,8 +94,8 @@
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'connections'
? 'bg-gray-200 dark:bg-gray-800'
: ' hover:bg-gray-300 dark:hover:bg-gray-850'}"
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
on:click={() => {
selectedTab = 'connections';
}}
@ -104,8 +118,8 @@
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'models'
? 'bg-gray-200 dark:bg-gray-800'
: ' hover:bg-gray-300 dark:hover:bg-gray-850'}"
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
on:click={() => {
selectedTab = 'models';
}}
@ -130,8 +144,8 @@
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'documents'
? 'bg-gray-200 dark:bg-gray-800'
: ' hover:bg-gray-300 dark:hover:bg-gray-850'}"
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
on:click={() => {
selectedTab = 'documents';
}}
@ -160,8 +174,8 @@
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'web'
? 'bg-gray-200 dark:bg-gray-800'
: ' hover:bg-gray-300 dark:hover:bg-gray-850'}"
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
on:click={() => {
selectedTab = 'web';
}}
@ -184,8 +198,8 @@
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'interface'
? 'bg-gray-200 dark:bg-gray-800'
: ' hover:bg-gray-300 dark:hover:bg-gray-850'}"
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
on:click={() => {
selectedTab = 'interface';
}}
@ -210,8 +224,8 @@
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'audio'
? 'bg-gray-200 dark:bg-gray-800'
: ' hover:bg-gray-300 dark:hover:bg-gray-850'}"
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
on:click={() => {
selectedTab = 'audio';
}}
@ -237,8 +251,8 @@
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'images'
? 'bg-gray-200 dark:bg-gray-800'
: ' hover:bg-gray-300 dark:hover:bg-gray-850'}"
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
on:click={() => {
selectedTab = 'images';
}}
@ -263,8 +277,8 @@
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'pipelines'
? 'bg-gray-200 dark:bg-gray-800'
: ' hover:bg-gray-300 dark:hover:bg-gray-850'}"
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
on:click={() => {
selectedTab = 'pipelines';
}}
@ -293,8 +307,8 @@
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'db'
? 'bg-gray-200 dark:bg-gray-800'
: ' hover:bg-gray-300 dark:hover:bg-gray-850'}"
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
on:click={() => {
selectedTab = 'db';
}}

View File

@ -138,7 +138,7 @@
<div>
<div class="mt-1 flex gap-2 mb-1">
<input
class="flex-1 w-full rounded-l-lg py-2 pl-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="flex-1 w-full rounded-l-lg py-2 pl-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('API Base URL')}
bind:value={STT_OPENAI_API_BASE_URL}
required
@ -156,7 +156,7 @@
<div class="flex-1">
<input
list="model-list"
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={STT_MODEL}
placeholder="Select a model"
/>
@ -203,7 +203,7 @@
<div>
<div class="mt-1 flex gap-2 mb-1">
<input
class="flex-1 w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="flex-1 w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('API Base URL')}
bind:value={TTS_OPENAI_API_BASE_URL}
required
@ -222,7 +222,7 @@
<div class="flex w-full">
<div class="flex-1">
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={TTS_VOICE}
>
<option value="" selected={TTS_VOICE !== ''}>{$i18n.t('Default')}</option>
@ -245,7 +245,7 @@
<div class="flex-1">
<input
list="voice-list"
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={TTS_VOICE}
placeholder="Select a voice"
/>
@ -264,7 +264,7 @@
<div class="flex-1">
<input
list="model-list"
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={TTS_MODEL}
placeholder="Select a model"
/>

View File

@ -200,7 +200,7 @@
<input
class="w-full rounded-lg py-2 px-4 {pipelineUrls[url]
? 'pr-8'
: ''} text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
: ''} text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('API Base URL')}
bind:value={url}
autocomplete="off"
@ -338,7 +338,7 @@
{#each OLLAMA_BASE_URLS as url, idx}
<div class="flex gap-1.5">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Enter URL (e.g. http://localhost:11434)')}
bind:value={url}
/>

View File

@ -37,6 +37,10 @@
let embeddingModel = '';
let rerankingModel = '';
let contentExtractionEngine = 'default';
let tikaServerUrl = '';
let showTikaServerUrl = false;
let chunkSize = 0;
let chunkOverlap = 0;
let pdfExtractImages = true;
@ -163,11 +167,20 @@
rerankingModelUpdateHandler();
}
if (contentExtractionEngine === 'tika' && tikaServerUrl === '') {
toast.error($i18n.t('Tika Server URL required.'));
return;
}
const res = await updateRAGConfig(localStorage.token, {
pdf_extract_images: pdfExtractImages,
chunk: {
chunk_overlap: chunkOverlap,
chunk_size: chunkSize
},
content_extraction: {
engine: contentExtractionEngine,
tika_server_url: tikaServerUrl
}
});
@ -213,6 +226,10 @@
chunkSize = res.chunk.chunk_size;
chunkOverlap = res.chunk.chunk_overlap;
contentExtractionEngine = res.content_extraction.engine;
tikaServerUrl = res.content_extraction.tika_server_url;
showTikaServerUrl = contentExtractionEngine === 'tika';
}
});
</script>
@ -262,7 +279,7 @@
</div>
<button
class=" self-center text-xs p-1 px-3 bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 rounded-lg flex flex-row space-x-1 items-center {scanDirLoading
class=" self-center text-xs p-1 px-3 bg-gray-50 dark:bg-gray-800 dark:hover:bg-gray-700 rounded-lg flex flex-row space-x-1 items-center {scanDirLoading
? ' cursor-not-allowed'
: ''}"
on:click={() => {
@ -335,7 +352,7 @@
{#if embeddingEngine === 'openai'}
<div class="my-0.5 flex gap-2">
<input
class="flex-1 w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="flex-1 w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('API Base URL')}
bind:value={OpenAIUrl}
required
@ -388,7 +405,7 @@
</div>
</div>
<hr class=" dark:border-gray-850 my-1" />
<hr class="dark:border-gray-850" />
<div class="space-y-2" />
<div>
@ -398,7 +415,7 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={embeddingModel}
placeholder={$i18n.t('Select a model')}
required
@ -407,7 +424,7 @@
<option value="" disabled selected>{$i18n.t('Select a model')}</option>
{/if}
{#each $models.filter((m) => m.id && m.ollama && !(m?.preset ?? false)) as model}
<option value={model.id} class="bg-gray-100 dark:bg-gray-700">{model.name}</option>
<option value={model.id} class="bg-gray-50 dark:bg-gray-700">{model.name}</option>
{/each}
</select>
</div>
@ -416,7 +433,7 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Set embedding model (e.g. {{model}})', {
model: embeddingModel.slice(-40)
})}
@ -426,7 +443,7 @@
{#if embeddingEngine === ''}
<button
class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
class="px-2.5 bg-gray-50 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
on:click={() => {
embeddingModelUpdateHandler();
}}
@ -495,7 +512,7 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Set reranking model (e.g. {{model}})', {
model: 'BAAI/bge-reranker-v2-m3'
})}
@ -503,7 +520,7 @@
/>
</div>
<button
class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
class="px-2.5 bg-gray-50 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
on:click={() => {
rerankingModelUpdateHandler();
}}
@ -562,6 +579,39 @@
<hr class=" dark:border-gray-850" />
<div class="">
<div class="text-sm font-medium">{$i18n.t('Content Extraction')}</div>
<div class="flex w-full justify-between mt-2">
<div class="self-center text-xs font-medium">{$i18n.t('Engine')}</div>
<div class="flex items-center relative">
<select
class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
bind:value={contentExtractionEngine}
on:change={(e) => {
showTikaServerUrl = e.target.value === 'tika';
}}
>
<option value="">{$i18n.t('Default')} </option>
<option value="tika">{$i18n.t('Tika')}</option>
</select>
</div>
</div>
{#if showTikaServerUrl}
<div class="flex w-full mt-2">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Enter Tika Server URL')}
bind:value={tikaServerUrl}
/>
</div>
</div>
{/if}
</div>
<hr class=" dark:border-gray-850" />
<div class=" ">
<div class=" text-sm font-medium">{$i18n.t('Query Params')}</div>
@ -571,7 +621,7 @@
<div class="self-center p-3">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Top K')}
bind:value={querySettings.k}
@ -589,7 +639,7 @@
<div class="self-center p-3">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
step="0.01"
placeholder={$i18n.t('Enter Score')}
@ -617,7 +667,7 @@
<div class=" mb-2.5 text-sm font-medium">{$i18n.t('RAG Template')}</div>
<textarea
bind:value={querySettings.template}
class="w-full rounded-lg px-4 py-3 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
class="w-full rounded-lg px-4 py-3 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
rows="4"
/>
</div>
@ -633,7 +683,7 @@
<div class="self-center text-xs font-medium min-w-fit mb-1">{$i18n.t('Chunk Size')}</div>
<div class="self-center">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Chunk Size')}
bind:value={chunkSize}
@ -650,7 +700,7 @@
<div class="self-center">
<input
class="w-full rounded-lg py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Chunk Overlap')}
bind:value={chunkOverlap}

View File

@ -107,7 +107,7 @@
<div class="flex mt-2 space-x-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="text"
placeholder={`e.g.) "30m","1h", "10d". `}
bind:value={adminConfig.JWT_EXPIRES_IN}
@ -131,7 +131,7 @@
<div class="flex mt-2 space-x-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="text"
placeholder={`https://example.com/webhook`}
bind:value={webhookUrl}

View File

@ -240,13 +240,13 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Enter URL (e.g. http://127.0.0.1:7860/)')}
bind:value={AUTOMATIC1111_BASE_URL}
/>
</div>
<button
class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
class="px-2.5 bg-gray-50 hover:bg-gray-100 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
type="button"
on:click={() => {
updateUrlHandler();
@ -299,13 +299,13 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Enter URL (e.g. http://127.0.0.1:7860/)')}
bind:value={COMFYUI_BASE_URL}
/>
</div>
<button
class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
class="px-2.5 bg-gray-50 hover:bg-gray-100 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
type="button"
on:click={() => {
updateUrlHandler();
@ -331,7 +331,7 @@
<div class="flex gap-2 mb-1">
<input
class="flex-1 w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="flex-1 w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('API Base URL')}
bind:value={OPENAI_API_BASE_URL}
required
@ -354,7 +354,7 @@
<div class="flex-1">
<input
list="model-list"
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={selectedModel}
placeholder="Select a model"
/>
@ -368,7 +368,7 @@
</div>
{:else}
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={selectedModel}
placeholder={$i18n.t('Select a model')}
required
@ -391,7 +391,7 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Enter Image Size (e.g. 512x512)')}
bind:value={imageSize}
/>
@ -404,7 +404,7 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Enter Number of Steps (e.g. 50)')}
bind:value={steps}
/>

View File

@ -88,7 +88,7 @@
<div class="flex-1">
<div class=" text-xs mb-1">{$i18n.t('Local Models')}</div>
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={taskConfig.TASK_MODEL}
placeholder={$i18n.t('Select a model')}
>
@ -104,7 +104,7 @@
<div class="flex-1">
<div class=" text-xs mb-1">{$i18n.t('External Models')}</div>
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={taskConfig.TASK_MODEL_EXTERNAL}
placeholder={$i18n.t('Select a model')}
>
@ -122,7 +122,7 @@
<div class=" mb-2.5 text-sm font-medium">{$i18n.t('Title Generation Prompt')}</div>
<textarea
bind:value={taskConfig.TITLE_GENERATION_PROMPT_TEMPLATE}
class="w-full rounded-lg py-3 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
class="w-full rounded-lg py-3 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
rows="6"
/>
</div>
@ -131,7 +131,7 @@
<div class=" mb-2.5 text-sm font-medium">{$i18n.t('Search Query Generation Prompt')}</div>
<textarea
bind:value={taskConfig.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE}
class="w-full rounded-lg py-3 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
class="w-full rounded-lg py-3 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
rows="6"
/>
</div>
@ -142,7 +142,7 @@
</div>
<input
bind:value={taskConfig.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD}
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
type="number"
/>
</div>
@ -273,24 +273,26 @@
</div>
<div class="grid lg:grid-cols-2 flex-col gap-1.5">
{#each promptSuggestions as prompt, promptIdx}
<div class=" flex dark:bg-gray-850 rounded-xl py-1.5">
<div
class=" flex border border-gray-100 dark:border-none dark:bg-gray-850 rounded-xl py-1.5"
>
<div class="flex flex-col flex-1 pl-1">
<div class="flex border-b dark:border-gray-800 w-full">
<div class="flex border-b border-gray-100 dark:border-gray-800 w-full">
<input
class="px-3 py-1.5 text-xs w-full bg-transparent outline-none border-r dark:border-gray-800"
class="px-3 py-1.5 text-xs w-full bg-transparent outline-none border-r border-gray-100 dark:border-gray-800"
placeholder={$i18n.t('Title (e.g. Tell me a fun fact)')}
bind:value={prompt.title[0]}
/>
<input
class="px-3 py-1.5 text-xs w-full bg-transparent outline-none border-r dark:border-gray-800"
class="px-3 py-1.5 text-xs w-full bg-transparent outline-none border-r border-gray-100 dark:border-gray-800"
placeholder={$i18n.t('Subtitle (e.g. about the Roman Empire)')}
bind:value={prompt.title[1]}
/>
</div>
<input
class="px-3 py-1.5 text-xs w-full bg-transparent outline-none border-r dark:border-gray-800"
class="px-3 py-1.5 text-xs w-full bg-transparent outline-none border-r border-gray-100 dark:border-gray-800"
placeholder={$i18n.t('Prompt (e.g. Tell me a fun fact about the Roman Empire)')}
bind:value={prompt.content}
/>

View File

@ -158,12 +158,14 @@
return;
}
const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
(error) => {
const [res, controller] = await pullModel(
localStorage.token,
sanitizedModelTag,
selectedOllamaUrlIdx
).catch((error) => {
toast.error(error);
return null;
}
);
});
if (res) {
const reader = res.body
@ -570,12 +572,12 @@
<div class="flex gap-2">
<div class="flex-1 pb-1">
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={selectedOllamaUrlIdx}
placeholder={$i18n.t('Select an Ollama instance')}
>
{#each OLLAMA_URLS as url, idx}
<option value={idx} class="bg-gray-100 dark:bg-gray-700">{url}</option>
<option value={idx} class="bg-gray-50 dark:bg-gray-700">{url}</option>
{/each}
</select>
</div>
@ -584,7 +586,7 @@
<div class="flex w-full justify-end">
<Tooltip content="Update All Models" placement="top">
<button
class="p-2.5 flex gap-2 items-center bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
class="p-2.5 flex gap-2 items-center bg-gray-50 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
on:click={() => {
updateModelsHandler();
}}
@ -619,7 +621,7 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Enter model tag (e.g. {{modelTag}})', {
modelTag: 'mistral:7b'
})}
@ -627,7 +629,7 @@
/>
</div>
<button
class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
class="px-2.5 bg-gray-50 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
on:click={() => {
pullModelHandler();
}}
@ -753,7 +755,7 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={deleteModelTag}
placeholder={$i18n.t('Select a model')}
>
@ -761,7 +763,7 @@
<option value="" disabled selected>{$i18n.t('Select a model')}</option>
{/if}
{#each $models.filter((m) => !(m?.preset ?? false) && m.owned_by === 'ollama' && (selectedOllamaUrlIdx === null ? true : (m?.ollama?.urls ?? []).includes(selectedOllamaUrlIdx))) as model}
<option value={model.name} class="bg-gray-100 dark:bg-gray-700"
<option value={model.name} class="bg-gray-50 dark:bg-gray-700"
>{model.name +
' (' +
(model.ollama.size / 1024 ** 3).toFixed(1) +
@ -771,7 +773,7 @@
</select>
</div>
<button
class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
class="px-2.5 bg-gray-50 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
on:click={() => {
showModelDeleteConfirm = true;
}}
@ -797,7 +799,7 @@
<div class="flex w-full">
<div class="flex-1 mr-2 flex flex-col gap-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Enter model tag (e.g. {{modelTag}})', {
modelTag: 'my-modelfile'
})}
@ -807,7 +809,7 @@
<textarea
bind:value={createModelContent}
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-100 dark:text-gray-100 dark:bg-gray-850 outline-none resize-none scrollbar-hidden"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-100 dark:bg-gray-850 outline-none resize-none scrollbar-hidden"
rows="6"
placeholder={`TEMPLATE """{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: """\nPARAMETER num_ctx 4096\nPARAMETER stop "</s>"\nPARAMETER stop "USER:"\nPARAMETER stop "ASSISTANT:"`}
disabled={createModelLoading}
@ -816,7 +818,7 @@
<div class="flex self-start">
<button
class="px-2.5 py-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition disabled:cursor-not-allowed"
class="px-2.5 py-2.5 bg-gray-50 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition disabled:cursor-not-allowed"
on:click={() => {
createModelHandler();
}}
@ -925,7 +927,7 @@
<button
type="button"
class="w-full rounded-lg text-left py-2 px-4 bg-white dark:text-gray-300 dark:bg-gray-850"
class="w-full rounded-lg text-left py-2 px-4 bg-gray-50 dark:text-gray-300 dark:bg-gray-850"
on:click={() => {
modelUploadInputElement.click();
}}
@ -940,7 +942,7 @@
{:else}
<div class="flex-1 {modelFileUrl !== '' ? 'mr-2' : ''}">
<input
class="w-full rounded-lg text-left py-2 px-4 bg-white dark:text-gray-300 dark:bg-gray-850 outline-none {modelFileUrl !==
class="w-full rounded-lg text-left py-2 px-4 bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none {modelFileUrl !==
''
? 'mr-2'
: ''}"
@ -955,7 +957,7 @@
{#if (modelUploadMode === 'file' && modelInputFile && modelInputFile.length > 0) || (modelUploadMode === 'url' && modelFileUrl !== '')}
<button
class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg disabled:cursor-not-allowed transition"
class="px-2.5 bg-gray-50 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg disabled:cursor-not-allowed transition"
type="submit"
disabled={modelTransferring}
>
@ -1014,7 +1016,7 @@
<div class=" my-2.5 text-sm font-medium">{$i18n.t('Modelfile Content')}</div>
<textarea
bind:value={modelFileContent}
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-100 dark:text-gray-100 dark:bg-gray-850 outline-none resize-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-100 dark:bg-gray-850 outline-none resize-none"
rows="6"
/>
</div>

View File

@ -19,6 +19,7 @@
} from '$lib/apis';
import Spinner from '$lib/components/common/Spinner.svelte';
import Switch from '$lib/components/common/Switch.svelte';
const i18n: Writable<i18nType> = getContext('i18n');
@ -213,7 +214,7 @@
<div class="flex gap-2">
<div class="flex-1">
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={selectedPipelinesUrlIdx}
placeholder={$i18n.t('Select a pipeline url')}
on:change={async () => {
@ -327,7 +328,7 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Enter Github Raw URL')}
bind:value={pipelineDownloadUrl}
/>
@ -411,7 +412,7 @@
<div class="flex gap-2">
<div class="flex-1">
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={selectedPipelineIdx}
placeholder={$i18n.t('Select a pipeline')}
on:change={async () => {
@ -476,15 +477,40 @@
</div>
{#if (valves[property] ?? null) !== null}
<div class="flex mt-0.5 space-x-2">
<!-- {valves[property]} -->
<div class="flex mt-0.5 mb-1.5 space-x-2">
<div class=" flex-1">
{#if valves_spec.properties[property]?.enum ?? null}
<select
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={valves[property]}
>
{#each valves_spec.properties[property].enum as option}
<option value={option} selected={option === valves[property]}>
{option}
</option>
{/each}
</select>
{:else if (valves_spec.properties[property]?.type ?? null) === 'boolean'}
<div class="flex justify-between items-center">
<div class="text-xs text-gray-500">
{valves[property] ? 'Enabled' : 'Disabled'}
</div>
<div class=" pr-2">
<Switch bind:state={valves[property]} />
</div>
</div>
{:else}
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="text"
placeholder={valves_spec.properties[property].title}
bind:value={valves[property]}
autocomplete="off"
required
/>
{/if}
</div>
</div>
{/if}

View File

@ -112,7 +112,7 @@
<div class="flex-1 mr-2">
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={defaultModelId}
placeholder="Select a model"
>
@ -140,7 +140,7 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={modelId}
placeholder="Select a model"
>

View File

@ -101,7 +101,7 @@
<div class="flex w-full">
<div class="flex-1">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="text"
placeholder={$i18n.t('Enter Searxng Query URL')}
bind:value={webConfig.search.searxng_query_url}
@ -129,7 +129,7 @@
<div class="flex w-full">
<div class="flex-1">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="text"
placeholder={$i18n.t('Enter Google PSE Engine Id')}
bind:value={webConfig.search.google_pse_engine_id}
@ -205,7 +205,7 @@
</div>
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Search Result Count')}
bind:value={webConfig.search.result_count}
required
@ -218,7 +218,7 @@
</div>
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Concurrent Requests')}
bind:value={webConfig.search.concurrent_requests}
required
@ -267,7 +267,7 @@
<div class=" w-20 text-xs font-medium self-center">{$i18n.t('Language')}</div>
<div class=" flex-1 self-center">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="text"
placeholder={$i18n.t('Enter language codes')}
bind:value={youtubeLanguage}

View File

@ -60,19 +60,26 @@
import Navbar from '$lib/components/layout/Navbar.svelte';
import CallOverlay from './MessageInput/CallOverlay.svelte';
import { error } from '@sveltejs/kit';
import ChatControls from './ChatControls.svelte';
import EventConfirmDialog from '../common/ConfirmDialog.svelte';
const i18n: Writable<i18nType> = getContext('i18n');
export let chatIdProp = '';
let loaded = false;
const eventTarget = new EventTarget();
let showControls = false;
let stopResponseFlag = false;
let autoScroll = true;
let processing = '';
let messagesContainerElement: HTMLDivElement;
let showEventConfirmation = false;
let eventConfirmationTitle = '';
let eventConfirmationMessage = '';
let eventCallback = null;
let showModelSelector = true;
let selectedModels = [''];
@ -96,6 +103,8 @@
currentId: null
};
let params = {};
$: if (history.currentId !== null) {
let _messages = [];
@ -126,6 +135,41 @@
})();
}
const chatEventHandler = async (event, cb) => {
if (event.chat_id === $chatId) {
await tick();
console.log(event);
let message = history.messages[event.message_id];
const type = event?.data?.type ?? null;
const data = event?.data?.data ?? null;
if (type === 'status') {
if (message?.statusHistory) {
message.statusHistory.push(data);
} else {
message.statusHistory = [data];
}
} else if (type === 'citation') {
if (message?.citations) {
message.citations.push(data);
} else {
message.citations = [data];
}
} else if (type === 'confirmation') {
eventCallback = cb;
showEventConfirmation = true;
eventConfirmationTitle = data.title;
eventConfirmationMessage = data.message;
} else {
console.log('Unknown message type', data);
}
messages = messages;
}
};
onMount(async () => {
const onMessageHandler = async (event) => {
if (event.origin === window.origin) {
@ -163,6 +207,8 @@
};
window.addEventListener('message', onMessageHandler);
$socket.on('chat-events', chatEventHandler);
if (!$chatId) {
chatId.subscribe(async (value) => {
if (!value) {
@ -177,6 +223,8 @@
return () => {
window.removeEventListener('message', onMessageHandler);
$socket.off('chat-events');
};
});
@ -196,6 +244,7 @@
messages: {},
currentId: null
};
params = {};
if ($page.url.searchParams.get('models')) {
selectedModels = $page.url.searchParams.get('models')?.split(',');
@ -265,11 +314,7 @@
await settings.set(JSON.parse(localStorage.getItem('settings') ?? '{}'));
}
await settings.set({
...$settings,
system: chatContent.system ?? $settings.system,
params: chatContent.options ?? $settings.params
});
params = chatContent?.params ?? {};
autoScroll = true;
await tick();
@ -302,7 +347,7 @@
}
};
const chatCompletedHandler = async (modelId, messages) => {
const chatCompletedHandler = async (modelId, responseMessageId, messages) => {
await mermaid.run({
querySelector: '.mermaid'
});
@ -316,7 +361,9 @@
info: m.info ? m.info : undefined,
timestamp: m.timestamp
})),
chat_id: $chatId
chat_id: $chatId,
session_id: $socket?.id,
id: responseMessageId
}).catch((error) => {
toast.error(error);
messages.at(-1).error = { content: error };
@ -480,9 +527,7 @@
title: $i18n.t('New Chat'),
models: selectedModels,
system: $settings.system ?? undefined,
options: {
...($settings.params ?? {})
},
params: params,
messages: messages,
history: history,
tags: [],
@ -580,11 +625,11 @@
scrollToBottom();
const messagesBody = [
$settings.system || (responseMessage?.userContext ?? null)
params?.system || $settings.system || (responseMessage?.userContext ?? null)
? {
role: 'system',
content: `${promptTemplate(
$settings?.system ?? '',
params?.system ?? $settings?.system ?? '',
$user.name,
$settings?.userLocation
? await getAndUpdateUserLocation(localStorage.token)
@ -665,25 +710,28 @@
await tick();
const [res, controller] = await generateChatCompletion(localStorage.token, {
stream: true,
model: model.id,
messages: messagesBody,
options: {
...($settings.params ?? {}),
...(params ?? $settings.params ?? {}),
stop:
$settings?.params?.stop ?? undefined
? $settings.params.stop.map((str) =>
params?.stop ?? $settings?.params?.stop ?? undefined
? (params?.stop ?? $settings.params.stop).map((str) =>
decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"'))
)
: undefined,
num_predict: $settings?.params?.max_tokens ?? undefined,
repeat_penalty: $settings?.params?.frequency_penalty ?? undefined
num_predict: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined,
repeat_penalty:
params?.frequency_penalty ?? $settings?.params?.frequency_penalty ?? undefined
},
format: $settings.requestFormat ?? undefined,
keep_alive: $settings.keepAlive ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined,
citations: files.length > 0 ? true : undefined,
chat_id: $chatId
session_id: $socket?.id,
chat_id: $chatId,
id: responseMessageId
});
if (res && res.ok) {
@ -704,7 +752,7 @@
controller.abort('User: Stop Response');
} else {
const messages = createMessagesList(responseMessageId);
await chatCompletedHandler(model.id, messages);
await chatCompletedHandler(model.id, responseMessageId, messages);
}
_response = responseMessage.content;
@ -811,7 +859,8 @@
chat = await updateChatById(localStorage.token, _chatId, {
messages: messages,
history: history,
models: selectedModels
models: selectedModels,
params: params
});
await chats.set(await getChatList(localStorage.token));
}
@ -912,8 +961,8 @@
const [res, controller] = await generateOpenAIChatCompletion(
localStorage.token,
{
model: model.id,
stream: true,
model: model.id,
stream_options:
model.info?.meta?.capabilities?.usage ?? false
? {
@ -921,11 +970,11 @@
}
: undefined,
messages: [
$settings.system || (responseMessage?.userContext ?? null)
params?.system || $settings.system || (responseMessage?.userContext ?? null)
? {
role: 'system',
content: `${promptTemplate(
$settings?.system ?? '',
params?.system ?? $settings?.system ?? '',
$user.name,
$settings?.userLocation
? await getAndUpdateUserLocation(localStorage.token)
@ -970,22 +1019,23 @@
: message?.raContent ?? message.content
})
})),
seed: $settings?.params?.seed ?? undefined,
seed: params?.seed ?? $settings?.params?.seed ?? undefined,
stop:
$settings?.params?.stop ?? undefined
? $settings.params.stop.map((str) =>
params?.stop ?? $settings?.params?.stop ?? undefined
? (params?.stop ?? $settings.params.stop).map((str) =>
decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"'))
)
: undefined,
temperature: $settings?.params?.temperature ?? undefined,
top_p: $settings?.params?.top_p ?? undefined,
frequency_penalty: $settings?.params?.frequency_penalty ?? undefined,
max_tokens: $settings?.params?.max_tokens ?? undefined,
temperature: params?.temperature ?? $settings?.params?.temperature ?? undefined,
top_p: params?.top_p ?? $settings?.params?.top_p ?? undefined,
frequency_penalty:
params?.frequency_penalty ?? $settings?.params?.frequency_penalty ?? undefined,
max_tokens: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined,
citations: files.length > 0 ? true : undefined,
chat_id: $chatId
session_id: $socket?.id,
chat_id: $chatId,
id: responseMessageId
},
`${WEBUI_BASE_URL}/api`
);
@ -1014,7 +1064,7 @@
} else {
const messages = createMessagesList(responseMessageId);
await chatCompletedHandler(model.id, messages);
await chatCompletedHandler(model.id, responseMessageId, messages);
}
_response = responseMessage.content;
@ -1086,7 +1136,8 @@
chat = await updateChatById(localStorage.token, _chatId, {
models: selectedModels,
messages: messages,
history: history
history: history,
params: params
});
await chats.set(await getChatList(localStorage.token));
}
@ -1353,6 +1404,18 @@
<audio id="audioElement" src="" style="display: none;" />
<EventConfirmDialog
bind:show={showEventConfirmation}
title={eventConfirmationTitle}
message={eventConfirmationMessage}
on:confirm={(e) => {
eventCallback(true);
}}
on:cancel={() => {
eventCallback(false);
}}
/>
{#if $showCallOverlay}
<CallOverlay
{submitPrompt}
@ -1387,6 +1450,7 @@
{title}
bind:selectedModels
bind:showModelSelector
bind:showControls
shareEnabled={messages.length > 0}
{chat}
{initNewChat}
@ -1396,7 +1460,7 @@
<div
class="absolute top-[4.25rem] w-full {$showSidebar
? 'md:max-w-[calc(100%-260px)]'
: ''} z-20"
: ''} {showControls ? 'lg:pr-[24rem]' : ''} z-20"
>
<div class=" flex flex-col gap-1 w-full">
{#each $banners.filter( (b) => (b.dismissible ? !JSON.parse(localStorage.getItem('dismissedBannerIds') ?? '[]').includes(b.id) : true) ) as banner}
@ -1423,7 +1487,9 @@
<div class="flex flex-col flex-auto z-10">
<div
class=" pb-2.5 flex flex-col justify-between w-full flex-auto overflow-auto h-0 max-w-full z-10"
class=" pb-2.5 flex flex-col justify-between w-full flex-auto overflow-auto h-0 max-w-full z-10 scrollbar-hidden {showControls
? 'lg:pr-[24rem]'
: ''}"
id="messages-container"
bind:this={messagesContainerElement}
on:scroll={(e) => {
@ -1448,6 +1514,8 @@
/>
</div>
</div>
<div class={showControls ? 'lg:pr-[24rem]' : ''}>
<MessageInput
bind:files
bind:prompt
@ -1470,4 +1538,7 @@
/>
</div>
</div>
<ChatControls bind:show={showControls} bind:params />
</div>
{/if}

View File

@ -0,0 +1,63 @@
<script lang="ts">
import { slide } from 'svelte/transition';
import Modal from '../common/Modal.svelte';
import Controls from './Controls/Controls.svelte';
import { onMount } from 'svelte';
export let show = false;
export let chatId = null;
export let params = {};
let largeScreen = false;
onMount(() => {
// listen to resize 1024px
const mediaQuery = window.matchMedia('(min-width: 1024px)');
const handleMediaQuery = (e) => {
if (e.matches) {
largeScreen = true;
} else {
largeScreen = false;
}
};
mediaQuery.addEventListener('change', handleMediaQuery);
handleMediaQuery(mediaQuery);
return () => {
mediaQuery.removeEventListener('change', handleMediaQuery);
};
});
</script>
{#if largeScreen}
{#if show}
<div class=" absolute bottom-0 right-0 z-20 h-full pointer-events-none">
<div class="pr-4 pt-14 pb-8 w-[24rem] h-full" in:slide={{ duration: 200, axis: 'x' }}>
<div
class="w-full h-full px-5 py-4 bg-white dark:shadow-lg dark:bg-gray-850 border border-gray-50 dark:border-gray-800 rounded-xl z-50 pointer-events-auto overflow-y-auto scrollbar-hidden"
>
<Controls
on:close={() => {
show = false;
}}
bind:params
/>
</div>
</div>
</div>
{/if}
{:else}
<Modal bind:show>
<div class=" px-6 py-4 h-full">
<Controls
on:close={() => {
show = false;
}}
bind:params
/>
</div>
</Modal>
{/if}

View File

@ -0,0 +1,49 @@
<script>
import { createEventDispatcher, getContext } from 'svelte';
const dispatch = createEventDispatcher();
const i18n = getContext('i18n');
import XMark from '$lib/components/icons/XMark.svelte';
import AdvancedParams from '../Settings/Advanced/AdvancedParams.svelte';
export let params = {};
</script>
<div class=" dark:text-white">
<div class=" flex justify-between dark:text-gray-100 mb-2">
<div class=" text-lg font-medium self-center font-primary">{$i18n.t('Chat Controls')}</div>
<button
class="self-center"
on:click={() => {
dispatch('close');
}}
>
<XMark className="size-4" />
</button>
</div>
<div class=" dark:text-gray-200 text-sm font-primary">
<div>
<div class="mb-1.5 font-medium">System Prompt</div>
<div>
<textarea
bind:value={params.system}
class="w-full rounded-lg px-4 py-3 text-sm dark:text-gray-300 dark:bg-gray-850 border border-gray-100 dark:border-gray-800 outline-none resize-none"
rows="3"
placeholder="Enter system prompt"
/>
</div>
</div>
<hr class="my-2 border-gray-100 dark:border-gray-800" />
<div>
<div class="mb-1.5 font-medium">Advanced Params</div>
<div>
<AdvancedParams bind:params />
</div>
</div>
</div>
</div>

View File

@ -316,7 +316,7 @@
</div>
{/if}
<div class="w-full">
<div class="w-full font-primary">
<div class=" -mb-0.5 mx-auto inset-x-0 bg-transparent flex justify-center">
<div class="flex flex-col max-w-6xl px-2.5 md:px-6 w-full">
<div class="relative">

View File

@ -662,10 +662,11 @@
: rmsLevel * 100 > 1
? 'size-14'
: 'size-12'} transition-all rounded-full {(model?.info?.meta
?.profile_image_url ?? '/favicon.png') !== '/favicon.png'
?.profile_image_url ?? '/static/favicon.png') !== '/static/favicon.png'
? ' bg-cover bg-center bg-no-repeat'
: 'bg-black dark:bg-white'} bg-black dark:bg-white"
style={(model?.info?.meta?.profile_image_url ?? '/favicon.png') !== '/favicon.png'
style={(model?.info?.meta?.profile_image_url ?? '/static/favicon.png') !==
'/static/favicon.png'
? `background-image: url('${model?.info?.meta?.profile_image_url}');`
: ''}
/>
@ -743,10 +744,11 @@
: rmsLevel * 100 > 1
? 'size-[11.5rem]'
: 'size-44'} transition-all rounded-full {(model?.info?.meta
?.profile_image_url ?? '/favicon.png') !== '/favicon.png'
?.profile_image_url ?? '/static/favicon.png') !== '/static/favicon.png'
? ' bg-cover bg-center bg-no-repeat'
: 'bg-black dark:bg-white'} "
style={(model?.info?.meta?.profile_image_url ?? '/favicon.png') !== '/favicon.png'
style={(model?.info?.meta?.profile_image_url ?? '/static/favicon.png') !==
'/static/favicon.png'
? `background-image: url('${model?.info?.meta?.profile_image_url}');`
: ''}
/>

View File

@ -162,7 +162,7 @@
.substring(1)
.startsWith('https://youtu.be'))}
<button
class="px-3 py-1.5 rounded-xl w-full text-left bg-gray-100 selected-command-option-button"
class="px-3 py-1.5 rounded-xl w-full text-left bg-gray-50 dark:bg-gray-850 dark:text-gray-100 selected-command-option-button"
type="button"
on:click={() => {
const url = prompt.split(' ')?.at(0)?.substring(1);
@ -177,7 +177,7 @@
}
}}
>
<div class=" font-medium text-black line-clamp-1">
<div class=" font-medium text-black dark:text-gray-100 line-clamp-1">
{prompt.split(' ')?.at(0)?.substring(1)}
</div>
@ -185,7 +185,7 @@
</button>
{:else if prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
<button
class="px-3 py-1.5 rounded-xl w-full text-left bg-gray-100 selected-command-option-button"
class="px-3 py-1.5 rounded-xl w-full text-left bg-gray-50 dark:bg-gray-850 dark:text-gray-100 selected-command-option-button"
type="button"
on:click={() => {
const url = prompt.split(' ')?.at(0)?.substring(1);
@ -200,7 +200,7 @@
}
}}
>
<div class=" font-medium text-black line-clamp-1">
<div class=" font-medium text-black dark:text-gray-100 line-clamp-1">
{prompt.split(' ')?.at(0)?.substring(1)}
</div>

View File

@ -250,7 +250,8 @@ __builtins__.input = input`);
stderr ||
result) &&
'border-bottom-left-radius: 0px; border-bottom-right-radius: 0px;'}"><code
class="language-{lang} rounded-t-none whitespace-pre">{@html highlightedCode || code}</code
class="language-{lang} rounded-t-none whitespace-pre"
>{#if highlightedCode}{@html highlightedCode}{:else}{code}{/if}</code
></pre>
<div

View File

@ -100,25 +100,27 @@
class="flex snap-x snap-mandatory overflow-x-auto scrollbar-hidden"
id="responses-container-{parentMessage.id}"
>
{#key currentMessageId}
{#each Object.keys(groupedMessages) as model}
{#if groupedMessagesIdx[model] !== undefined && groupedMessages[model].messages.length > 0}
<!-- svelte-ignore a11y-no-static-element-interactions -->
<!-- svelte-ignore a11y-click-events-have-key-events -->
{@const message = groupedMessages[model].messages[groupedMessagesIdx[model]]}
<div
class=" snap-center min-w-80 w-full max-w-full m-1 border {history.messages[
currentMessageId
].model === model
? 'border-gray-100 dark:border-gray-850 border-[1.5px]'
? 'border-gray-100 dark:border-gray-800 border-[1.5px]'
: 'border-gray-50 dark:border-gray-850 '} transition p-5 rounded-3xl"
on:click={() => {
currentMessageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id;
let messageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id;
if (currentMessageId != message.id) {
currentMessageId = message.id;
let messageId = message.id;
console.log(messageId);
let messageChildrenIds = history.messages[messageId].childrenIds;
//
let messageChildrenIds = history.messages[messageId].childrenIds;
while (messageChildrenIds.length !== 0) {
messageId = messageChildrenIds.at(-1);
messageChildrenIds = history.messages[messageId].childrenIds;
@ -126,6 +128,7 @@
history.currentId = messageId;
dispatch('change');
}
}}
>
<ResponseMessage
@ -159,5 +162,6 @@
</div>
{/if}
{/each}
{/key}
</div>
</div>

View File

@ -1,3 +1,3 @@
<div class=" self-center font-bold mb-0.5 line-clamp-1 contents">
<div class=" self-center font-semibold mb-0.5 line-clamp-1 contents">
<slot />
</div>

View File

@ -65,7 +65,7 @@
</div>
<div
class=" mt-2 mb-4 text-3xl text-gray-800 dark:text-gray-100 font-semibold text-left flex items-center gap-4"
class=" mt-2 mb-4 text-3xl text-gray-800 dark:text-gray-100 font-semibold text-left flex items-center gap-4 font-primary"
>
<div>
<div class=" capitalize line-clamp-1" in:fade={{ duration: 200 }}>
@ -102,7 +102,7 @@
</div>
{/if}
{:else}
<div class=" font-medium text-gray-400 dark:text-gray-500 line-clamp-1">
<div class=" font-medium text-gray-400 dark:text-gray-500 line-clamp-1 font-p">
{$i18n.t('How can I help you today?')}
</div>
{/if}
@ -110,7 +110,7 @@
</div>
</div>
<div class=" w-full" in:fade={{ duration: 200, delay: 300 }}>
<div class=" w-full font-primary" in:fade={{ duration: 200, delay: 300 }}>
<Suggestions
suggestionPrompts={models[selectedModelIdx]?.info?.meta?.suggestion_prompts ??
$config.default_prompt_suggestions}

View File

@ -152,7 +152,10 @@
}
tooltipInstance = tippy(`#info-${message.id}`, {
content: `<span class="text-xs" id="tooltip-${message.id}">${tooltipContent}</span>`,
allowHTML: true
allowHTML: true,
theme: 'dark',
arrow: false,
offset: [0, 4]
});
}
};

View File

@ -2,15 +2,13 @@
import ChevronDown from '$lib/components/icons/ChevronDown.svelte';
import ChevronUp from '$lib/components/icons/ChevronUp.svelte';
import MagnifyingGlass from '$lib/components/icons/MagnifyingGlass.svelte';
import { Collapsible } from 'bits-ui';
import { slide } from 'svelte/transition';
import Collapsible from '$lib/components/common/Collapsible.svelte';
export let status = { urls: [], query: '' };
let state = false;
</script>
<Collapsible.Root class="w-full space-y-1" bind:open={state}>
<Collapsible.Trigger>
<Collapsible bind:open={state} className="w-full space-y-1">
<div
class="flex items-center gap-2 text-gray-500 hover:text-gray-700 dark:hover:text-gray-300 transition"
>
@ -22,12 +20,7 @@
<ChevronDown strokeWidth="3.5" className="size-3.5 " />
{/if}
</div>
</Collapsible.Trigger>
<Collapsible.Content
class=" text-sm border border-gray-300/30 dark:border-gray-700/50 rounded-xl"
transition={slide}
>
<div class="text-sm border border-gray-300/30 dark:border-gray-700/50 rounded-xl" slot="content">
{#if status?.query}
<a
href="https://www.google.com/search?q={status.query}"
@ -93,5 +86,5 @@
</div>
</a>
{/each}
</Collapsible.Content>
</Collapsible.Root>
</div>
</Collapsible>

View File

@ -34,7 +34,7 @@
}
</script>
<div class="flex flex-col w-full items-center md:items-start">
<div class="flex flex-col w-full items-start">
{#each selectedModels as selectedModel, selectedModelIdx}
<div class="flex w-full max-w-fit">
<div class="overflow-hidden w-full">
@ -103,7 +103,7 @@
</div>
{#if showSetDefault && !$mobile}
<div class="text-left mt-0.5 ml-1 text-[0.7rem] text-gray-500">
<div class="text-left mt-0.5 ml-1 text-[0.7rem] text-gray-500 font-primary">
<button on:click={saveDefaultModel}> {$i18n.t('Set as default')}</button>
</div>
{/if}

View File

@ -206,7 +206,7 @@
}}
closeFocus={false}
>
<DropdownMenu.Trigger class="relative w-full" aria-label={placeholder}>
<DropdownMenu.Trigger class="relative w-full font-primary" aria-label={placeholder}>
<div
class="flex w-full text-left px-0.5 outline-none bg-transparent truncate text-lg font-semibold placeholder-gray-400 focus:outline-none"
>
@ -222,7 +222,7 @@
<DropdownMenu.Content
class=" z-40 {$mobile
? `w-full`
: `${className}`} max-w-[calc(100vw-1rem)] justify-start rounded-xl bg-white dark:bg-gray-850 dark:text-white shadow-lg border border-gray-300/30 dark:border-gray-850/50 outline-none "
: `${className}`} max-w-[calc(100vw-1rem)] justify-start rounded-xl bg-white dark:bg-gray-850 dark:text-white shadow-lg border border-gray-300/30 dark:border-gray-700/40 outline-none"
transition={flyAndScale}
side={$mobile ? 'bottom' : 'bottom-start'}
sideOffset={4}
@ -260,7 +260,7 @@
<div class="flex gap-0.5 self-start h-full mb-0.5 -translate-x-1">
{#each item.model?.info?.meta.tags as tag}
<div
class=" text-xs font-black px-1 rounded uppercase line-clamp-1 bg-gray-500/20 text-gray-700 dark:text-gray-200"
class=" text-xs font-bold px-1 rounded uppercase line-clamp-1 bg-gray-500/20 text-gray-700 dark:text-gray-200"
>
{tag.name}
</div>
@ -299,7 +299,7 @@
<div class="flex gap-0.5 self-center items-center h-full translate-y-[0.5px]">
{#each item.model?.info?.meta.tags as tag}
<div
class=" text-xs font-black px-1 rounded uppercase line-clamp-1 bg-gray-500/20 text-gray-700 dark:text-gray-200"
class=" text-xs font-bold px-1 rounded uppercase line-clamp-1 bg-gray-500/20 text-gray-700 dark:text-gray-200"
>
{tag.name}
</div>

View File

@ -44,7 +44,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Seed')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.seed = (params?.seed ?? null) === null ? 0 : null;
@ -79,7 +79,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Stop Sequence')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.stop = (params?.stop ?? null) === null ? '' : null;
@ -113,7 +113,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Temperature')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.temperature = (params?.temperature ?? null) === null ? 0.8 : null;
@ -159,7 +159,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Mirostat')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.mirostat = (params?.mirostat ?? null) === null ? 0 : null;
@ -205,7 +205,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Mirostat Eta')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.mirostat_eta = (params?.mirostat_eta ?? null) === null ? 0.1 : null;
@ -251,7 +251,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Mirostat Tau')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.mirostat_tau = (params?.mirostat_tau ?? null) === null ? 5.0 : null;
@ -297,7 +297,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Top K')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.top_k = (params?.top_k ?? null) === null ? 40 : null;
@ -343,7 +343,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Top P')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.top_p = (params?.top_p ?? null) === null ? 0.9 : null;
@ -389,7 +389,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Frequency Penalty')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.frequency_penalty = (params?.frequency_penalty ?? null) === null ? 1.1 : null;
@ -435,7 +435,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Repeat Last N')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.repeat_last_n = (params?.repeat_last_n ?? null) === null ? 64 : null;
@ -481,7 +481,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Tfs Z')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.tfs_z = (params?.tfs_z ?? null) === null ? 1 : null;
@ -527,7 +527,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Context Length')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.num_ctx = (params?.num_ctx ?? null) === null ? 2048 : null;
@ -572,7 +572,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Batch Size (num_batch)')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.num_batch = (params?.num_batch ?? null) === null ? 512 : null;
@ -619,7 +619,7 @@
</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.num_keep = (params?.num_keep ?? null) === null ? 24 : null;
@ -664,7 +664,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Max Tokens (num_predict)')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.max_tokens = (params?.max_tokens ?? null) === null ? 128 : null;
@ -711,7 +711,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('use_mmap (Ollama)')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.use_mmap = (params?.use_mmap ?? null) === null ? true : null;
@ -731,7 +731,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('use_mlock (Ollama)')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.use_mlock = (params?.use_mlock ?? null) === null ? true : null;
@ -751,7 +751,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('num_thread (Ollama)')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.num_thread = (params?.num_thread ?? null) === null ? 2 : null;
@ -797,7 +797,7 @@
<div class=" self-center text-xs font-medium">{$i18n.t('Template')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
class="p-1 px-3 text-xs flex rounded transition flex-shrink-0 outline-none"
type="button"
on:click={() => {
params.template = (params?.template ?? null) === null ? '' : null;

View File

@ -95,6 +95,8 @@
}
if (themeToApply === 'dark' && !_theme.includes('oled')) {
document.documentElement.style.setProperty('--color-gray-800', '#333');
document.documentElement.style.setProperty('--color-gray-850', '#262626');
document.documentElement.style.setProperty('--color-gray-900', '#171717');
document.documentElement.style.setProperty('--color-gray-950', '#0d0d0d');
}
@ -118,6 +120,8 @@
theme.set(_theme);
localStorage.setItem('theme', _theme);
if (_theme.includes('oled')) {
document.documentElement.style.setProperty('--color-gray-800', '#101010');
document.documentElement.style.setProperty('--color-gray-850', '#050505');
document.documentElement.style.setProperty('--color-gray-900', '#000000');
document.documentElement.style.setProperty('--color-gray-950', '#000000');
document.documentElement.classList.add('dark');

View File

@ -18,6 +18,7 @@
import ManageModal from './Personalization/ManageModal.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
import Spinner from '$lib/components/common/Spinner.svelte';
import Switch from '$lib/components/common/Switch.svelte';
const dispatch = createEventDispatcher();
@ -185,7 +186,10 @@
class="p-1 px-3 text-xs flex rounded transition"
type="button"
on:click={() => {
valves[property] = (valves[property] ?? null) === null ? '' : null;
valves[property] =
(valves[property] ?? null) === null
? valvesSpec.properties[property]?.default ?? ''
: null;
}}
>
{#if (valves[property] ?? null) === null}
@ -203,8 +207,31 @@
</div>
{#if (valves[property] ?? null) !== null}
<!-- {valves[property]} -->
<div class="flex mt-0.5 mb-1.5 space-x-2">
<div class=" flex-1">
{#if valvesSpec.properties[property]?.enum ?? null}
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={valves[property]}
>
{#each valvesSpec.properties[property].enum as option}
<option value={option} selected={option === valves[property]}>
{option}
</option>
{/each}
</select>
{:else if (valvesSpec.properties[property]?.type ?? null) === 'boolean'}
<div class="flex justify-between items-center">
<div class="text-xs text-gray-500">
{valves[property] ? 'Enabled' : 'Disabled'}
</div>
<div class=" pr-2">
<Switch bind:state={valves[property]} />
</div>
</div>
{:else}
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
type="text"
@ -213,6 +240,7 @@
autocomplete="off"
required
/>
{/if}
</div>
</div>
{/if}

View File

@ -1,9 +1,10 @@
<script lang="ts">
import { getContext } from 'svelte';
import { getContext, tick } from 'svelte';
import { toast } from 'svelte-sonner';
import { models, settings, user } from '$lib/stores';
import { updateUserSettings } from '$lib/apis/users';
import { getModels as _getModels } from '$lib/apis';
import { goto } from '$app/navigation';
import Modal from '../common/Modal.svelte';
import Account from './Settings/Account.svelte';
@ -14,8 +15,6 @@
import Chats from './Settings/Chats.svelte';
import User from '../icons/User.svelte';
import Personalization from './Settings/Personalization.svelte';
import { updateUserSettings } from '$lib/apis/users';
import { goto } from '$app/navigation';
import Valves from './Settings/Valves.svelte';
const i18n = getContext('i18n');
@ -34,6 +33,37 @@
};
let selectedTab = 'general';
// Function to handle sideways scrolling
const scrollHandler = (event) => {
const settingsTabsContainer = document.getElementById('settings-tabs-container');
if (settingsTabsContainer) {
event.preventDefault(); // Prevent default vertical scrolling
settingsTabsContainer.scrollLeft += event.deltaY; // Scroll sideways
}
};
const addScrollListener = async () => {
await tick();
const settingsTabsContainer = document.getElementById('settings-tabs-container');
if (settingsTabsContainer) {
settingsTabsContainer.addEventListener('wheel', scrollHandler);
}
};
const removeScrollListener = async () => {
await tick();
const settingsTabsContainer = document.getElementById('settings-tabs-container');
if (settingsTabsContainer) {
settingsTabsContainer.removeEventListener('wheel', scrollHandler);
}
};
$: if (show) {
addScrollListener();
} else {
removeScrollListener();
}
</script>
<Modal bind:show>
@ -61,6 +91,7 @@
<div class="flex flex-col md:flex-row w-full p-4 md:space-x-4">
<div
id="settings-tabs-container"
class="tabs flex flex-row overflow-x-auto space-x-1 md:space-x-0 md:space-y-1 md:flex-col flex-1 md:flex-none md:w-40 dark:text-gray-200 text-xs text-left mb-3 md:mb-0"
>
<button

View File

@ -8,7 +8,7 @@
getTagsById,
updateChatById
} from '$lib/apis/chats';
import { tags as _tags, chats } from '$lib/stores';
import { tags as _tags, chats, pinnedChats } from '$lib/stores';
import { createEventDispatcher, onMount } from 'svelte';
const dispatch = createEventDispatcher();
@ -19,9 +19,11 @@
let tags = [];
const getTags = async () => {
return await getTagsById(localStorage.token, chatId).catch(async (error) => {
return (
await getTagsById(localStorage.token, chatId).catch(async (error) => {
return [];
});
})
).filter((tag) => tag.name !== 'pinned');
};
const addTag = async (tagName) => {
@ -33,6 +35,7 @@
});
_tags.set(await getAllChatTags(localStorage.token));
await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
};
const deleteTag = async (tagName) => {
@ -44,19 +47,23 @@
});
console.log($_tags);
await _tags.set(await getAllChatTags(localStorage.token));
console.log($_tags);
if ($_tags.map((t) => t.name).includes(tagName)) {
if (tagName === 'pinned') {
await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
} else {
await chats.set(await getChatListByTagName(localStorage.token, tagName));
}
if ($chats.find((chat) => chat.id === chatId)) {
dispatch('close');
}
} else {
await chats.set(await getChatList(localStorage.token));
await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
}
};

View File

@ -45,7 +45,7 @@
<div class=" flex flex-col md:flex-row md:items-center flex-1 text-sm w-fit gap-1.5">
<div class="flex justify-between self-start">
<div
class=" text-xs font-black {classNames[banner.type] ??
class=" text-xs font-bold {classNames[banner.type] ??
classNames['info']} w-fit px-2 rounded uppercase line-clamp-1 mr-0.5"
>
{banner.type}
@ -54,7 +54,7 @@
{#if banner.url}
<div class="flex md:hidden group w-fit md:items-center">
<a
class="text-gray-700 dark:text-white text-xs font-bold underline"
class="text-gray-700 dark:text-white text-xs font-semibold underline"
href="/assets/files/whitepaper.pdf"
target="_blank">Learn More</a
>
@ -88,7 +88,7 @@
{#if banner.url}
<div class="hidden md:flex group w-fit md:items-center">
<a
class="text-gray-700 dark:text-white text-xs font-bold underline"
class="text-gray-700 dark:text-white text-xs font-semibold underline"
href="/"
target="_blank">Learn More</a
>
@ -116,7 +116,8 @@
on:click={() => {
dismiss(banner.id);
}}
class=" -mt-[3px] ml-1.5 mr-1 text-gray-400 dark:hover:text-white h-1">&times;</button
class=" -mt-1 -mb-2 -translate-y-[1px] ml-1.5 mr-1 text-gray-400 dark:hover:text-white"
>&times;</button
>
{/if}
</div>

View File

@ -0,0 +1,18 @@
<script lang="ts">
import { slide } from 'svelte/transition';
import { quintOut } from 'svelte/easing';
export let open = false;
export let className = '';
</script>
<div class={className}>
<button on:click={() => (open = !open)}>
<slot />
</button>
{#if open}
<div transition:slide={{ duration: 300, easing: quintOut, axis: 'y' }}>
<slot name="content" />
</div>
{/if}
</div>

View File

@ -7,8 +7,8 @@
const dispatch = createEventDispatcher();
export let title = $i18n.t('Confirm your action');
export let message = $i18n.t('This action cannot be undone. Do you wish to continue?');
export let title = '';
export let message = '';
export let cancelLabel = $i18n.t('Cancel');
export let confirmLabel = $i18n.t('Confirm');
@ -58,11 +58,21 @@
}}
>
<div class="px-[1.75rem] py-6">
<div class=" text-lg font-semibold dark:text-gray-200 mb-2.5">{title}</div>
<div class=" text-lg font-semibold dark:text-gray-200 mb-2.5">
{#if title !== ''}
{title}
{:else}
{$i18n.t('Confirm your action')}
{/if}
</div>
<slot>
<div class=" text-sm text-gray-500">
{#if message !== ''}
{message}
{:else}
{$i18n.t('This action cannot be undone. Do you wish to continue?')}
{/if}
</div>
</slot>
@ -71,6 +81,7 @@
class="bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-white font-medium w-full py-2.5 rounded-lg transition"
on:click={() => {
show = false;
dispatch('cancel');
}}
type="button"
>

View File

@ -63,7 +63,7 @@
<div
class=" m-auto rounded-2xl max-w-full {sizeToWidth(
size
)} mx-2 bg-gray-50 dark:bg-gray-900 shadow-3xl"
)} mx-2 bg-gray-50 dark:bg-gray-900 shadow-3xl max-h-[100dvh] overflow-y-auto scrollbar-hidden"
in:flyAndScale
on:mousedown={(e) => {
e.stopPropagation();

View File

@ -48,7 +48,7 @@
<ChevronDown className="absolute end-2 top-1/2 -translate-y-[45%] size-3.5" strokeWidth="2.5" />
</Select.Trigger>
<Select.Content
class="w-full rounded-lg bg-white dark:bg-gray-900 dark:text-white shadow-lg border border-gray-300/30 dark:border-gray-850/50 outline-none"
class="w-full rounded-lg bg-white dark:bg-gray-900 dark:text-white shadow-lg border border-gray-300/30 dark:border-gray-700/40 outline-none"
transition={flyAndScale}
sideOffset={4}
>

Some files were not shown because too many files have changed in this diff Show More